多机并行计算
This commit is contained in:
@ -15,8 +15,8 @@ base:
|
||||
|
||||
# 模型配置
|
||||
models:
|
||||
backbone: 'resnet50'
|
||||
channel_ratio: 1.0
|
||||
backbone: 'resnet18'
|
||||
channel_ratio: 0.75
|
||||
|
||||
# 训练参数
|
||||
training:
|
||||
@ -31,9 +31,9 @@ training:
|
||||
weight_decay: 0.0005 # 权重衰减
|
||||
scheduler: "step" # 学习率调度器(可选:cosine/cosine_warm/step/None)
|
||||
num_workers: 32 # 数据加载线程数
|
||||
checkpoints: "./checkpoints/resnet50_electornic_20250807/" # 模型保存目录
|
||||
checkpoints: "./checkpoints/resnet18_pdd_test/" # 模型保存目录
|
||||
restore: false
|
||||
restore_model: "./checkpoints/resnet18_20250717_scale=0.75_nosub/best.pth" # 模型恢复路径
|
||||
restore_model: "./checkpoints/resnet50_electornic_20250807/best.pth" # 模型恢复路径
|
||||
cosine_t_0: 10 # 初始周期长度
|
||||
cosine_t_mult: 1 # 周期长度倍率
|
||||
cosine_eta_min: 0.00001 # 最小学习率
|
||||
@ -70,3 +70,5 @@ logging:
|
||||
distributed:
|
||||
enabled: false # 是否启用分布式训练
|
||||
backend: "nccl" # 分布式后端(nccl/gloo)
|
||||
node_rank: 0 # 节点编号
|
||||
node_num: 1 # 共计几个节点 一般几台机器就有几个节点
|
||||
|
@ -15,7 +15,8 @@ base:
|
||||
models:
|
||||
backbone: 'resnet18'
|
||||
channel_ratio: 0.75
|
||||
model_path: "../checkpoints/resnet18_1009/best.pth"
|
||||
model_path: "../checkpoints/resnet18_20250715_scale=0.75_sub/best.pth"
|
||||
# model_path: "../checkpoints/resnet18_1009/best.pth"
|
||||
|
||||
heatmap:
|
||||
feature_layer: "layer4"
|
||||
|
@ -15,7 +15,8 @@ base:
|
||||
models:
|
||||
backbone: 'resnet18'
|
||||
channel_ratio: 0.75
|
||||
model_path: "../checkpoints/resnet18_1009/best.pth"
|
||||
# model_path: "../checkpoints/resnet18_1009/best.pth"
|
||||
model_path: "../checkpoints/resnet18_20250715_scale=0.75_sub/best.pth"
|
||||
|
||||
heatmap:
|
||||
feature_layer: "layer4"
|
||||
|
@ -13,8 +13,8 @@ base:
|
||||
# 模型配置
|
||||
models:
|
||||
backbone: 'resnet18'
|
||||
channel_ratio: 0.75
|
||||
model_path: "checkpoints/resnet18_1009/best.pth"
|
||||
channel_ratio: 1.0
|
||||
model_path: "checkpoints/resnet18_electornic_20250806/best.pth"
|
||||
#resnet18_20250715_scale=0.75_sub
|
||||
#resnet18_20250718_scale=0.75_nosub
|
||||
half: false # 是否启用半精度测试(fp16)
|
||||
@ -24,11 +24,11 @@ models:
|
||||
data:
|
||||
test_batch_size: 128 # test batch size
|
||||
num_workers: 32 # 数据加载线程数
|
||||
test_dir: "../data_center/contrast_data/v1/extra" # 验证数据集根目录
|
||||
test_dir: "../data_center/electornic/v1/val" # 验证数据集根目录
|
||||
test_group_json: "../data_center/contrast_learning/model_test_data/test/inner_group_pairs.json"
|
||||
test_list: "../data_center/contrast_data/v1/extra_cross_same.txt"
|
||||
test_list: "../data_center/electornic/v1/cross_same.txt"
|
||||
group_test: false
|
||||
save_image_joint: true
|
||||
save_image_joint: false
|
||||
image_joint_pth: "./joint_images"
|
||||
|
||||
transform:
|
||||
|
@ -10,15 +10,16 @@ base:
|
||||
embedding_size: 256 # 特征维度
|
||||
pin_memory: true # 是否启用pin_memory
|
||||
distributed: true # 是否启用分布式训练
|
||||
dataset: "./dataset_electornic.txt" # 数据集名称
|
||||
|
||||
|
||||
# 模型配置
|
||||
models:
|
||||
backbone: 'resnet18'
|
||||
backbone: 'resnet101'
|
||||
channel_ratio: 1.0
|
||||
model_path: "../checkpoints/resnet18_1009/best.pth"
|
||||
onnx_model: "../checkpoints/resnet18_3399_sancheng/best.onnx"
|
||||
rknn_model: "../checkpoints/resnet18_3399_sancheng/best_rknn2.3.2_RK3566.rknn"
|
||||
model_path: "../checkpoints/resnet101_electornic_20250807/best.pth"
|
||||
onnx_model: "../checkpoints/resnet101_electornic_20250807/best.onnx"
|
||||
rknn_model: "../checkpoints/resnet101_electornic_20250807/resnet101_electornic.rknn"
|
||||
rknn_batch_size: 1
|
||||
|
||||
# 日志与监控
|
||||
|
@ -1,4 +1,4 @@
|
||||
from model import (resnet18, resnet34, resnet50, mobilevit_s, MobileNetV3_Small, MobileNetV3_Large, mobilenet_v1,
|
||||
from model import (resnet18, resnet34, resnet50, resnet101, mobilevit_s, MobileNetV3_Small, MobileNetV3_Large, mobilenet_v1,
|
||||
PPLCNET_x1_0, PPLCNET_x0_5, PPLCNET_x2_5)
|
||||
from timm.models import vit_base_patch16_224 as vit_base_16
|
||||
from model.metric import ArcFace, CosFace
|
||||
@ -13,9 +13,10 @@ class trainer_tools:
|
||||
|
||||
def get_backbone(self):
|
||||
backbone_mapping = {
|
||||
'resnet18': lambda: resnet18(scale=self.conf['models']['channel_ratio']),
|
||||
'resnet34': lambda: resnet34(scale=self.conf['models']['channel_ratio']),
|
||||
'resnet50': lambda: resnet50(scale=self.conf['models']['channel_ratio']),
|
||||
'resnet18': lambda: resnet18(scale=self.conf['models']['channel_ratio'], pretrained=True),
|
||||
'resnet34': lambda: resnet34(scale=self.conf['models']['channel_ratio'], pretrained=True),
|
||||
'resnet50': lambda: resnet50(scale=self.conf['models']['channel_ratio'], pretrained=True),
|
||||
'resnet101': lambda: resnet101(scale=self.conf['models']['channel_ratio'], pretrained=True),
|
||||
'mobilevit_s': lambda: mobilevit_s(),
|
||||
'mobilenetv3_small': lambda: MobileNetV3_Small(),
|
||||
'PPLCNET_x1_0': lambda: PPLCNET_x1_0(),
|
||||
|
@ -12,9 +12,9 @@ base:
|
||||
|
||||
# 模型配置
|
||||
models:
|
||||
backbone: 'resnet18'
|
||||
channel_ratio: 0.75
|
||||
checkpoints: "../checkpoints/resnet18_20250715_scale=0.75_sub/best.pth"
|
||||
backbone: 'resnet101'
|
||||
channel_ratio: 1.0
|
||||
checkpoints: "../checkpoints/resnet101_electornic_20250807/best.pth"
|
||||
|
||||
# 数据配置
|
||||
data:
|
||||
@ -22,7 +22,7 @@ data:
|
||||
test_batch_size: 128 # 验证批次大小
|
||||
num_workers: 32 # 数据加载线程数
|
||||
half: true # 是否启用半精度数据
|
||||
img_dirs_path: "/home/lc/data_center/baseStlib/pic/stlib_base" # base标准库图片存储路径
|
||||
img_dirs_path: "/shareData/completed_data/scatter_data/electronic_scale/base/total" # base标准库图片存储路径
|
||||
# img_dirs_path: "/home/lc/contrast_nettest/data/feature_json"
|
||||
xlsx_pth: false # 过滤商品, 默认None不进行过滤
|
||||
|
||||
@ -41,8 +41,8 @@ logging:
|
||||
checkpoint_interval: 30 # 检查点保存间隔(epoch)
|
||||
|
||||
save:
|
||||
json_bin: "../search_library/yunhedian_05-09.json" # 保存整个json文件
|
||||
json_bin: "../search_library/resnet101_electronic.json" # 保存整个json文件
|
||||
json_path: "/home/lc/data_center/baseStlib/feature_json/stlib_base_resnet18_sub" # 保存单个json文件路径
|
||||
error_barcodes: "error_barcodes.txt"
|
||||
barcodes_statistics: "../search_library/barcodes_statistics.txt"
|
||||
create_single_json: true # 是否保存单个json文件
|
||||
create_single_json: false # 是否保存单个json文件
|
Reference in New Issue
Block a user