Multi-machine parallel computing

lee
2025-08-14 10:09:54 +08:00
parent bc896fc688
commit 99a204ee22
18 changed files with 105 additions and 55 deletions

.gitignore vendored
View File

@@ -8,4 +8,5 @@ loss/
 checkpoints/
 search_library/
 quant_imgs/
+electronic_imgs/
 README.md

View File

@@ -15,8 +15,8 @@ base:
 # Model configuration
 models:
-  backbone: 'resnet50'
-  channel_ratio: 1.0
+  backbone: 'resnet18'
+  channel_ratio: 0.75
 # Training parameters
 training:
@@ -31,9 +31,9 @@ training:
   weight_decay: 0.0005  # weight decay
   scheduler: "step"  # LR scheduler: cosine/cosine_warm/step/None
   num_workers: 32  # number of data-loader workers
-  checkpoints: "./checkpoints/resnet50_electornic_20250807/"  # checkpoint save directory
+  checkpoints: "./checkpoints/resnet18_pdd_test/"  # checkpoint save directory
   restore: false
-  restore_model: "./checkpoints/resnet18_20250717_scale=0.75_nosub/best.pth"  # checkpoint restore path
+  restore_model: "./checkpoints/resnet50_electornic_20250807/best.pth"  # checkpoint restore path
   cosine_t_0: 10  # initial cycle length
   cosine_t_mult: 1  # cycle length multiplier
   cosine_eta_min: 0.00001  # minimum learning rate
@@ -70,3 +70,5 @@ logging:
 distributed:
   enabled: false  # enable distributed training
   backend: "nccl"  # distributed backend: nccl/gloo
+  node_rank: 0  # index of this node
+  node_num: 1  # total number of nodes; typically one node per machine
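
These two keys drive the multi-node launch in the train-script hunk at the end of this commit: each machine spawns one process per local GPU, and node_rank offsets the global ranks. A minimal sketch of the arithmetic (gpus_per_node is an assumed value here; the script reads it from torch.cuda.device_count()):

gpus_per_node = 4                        # assumed per-machine GPU count
node_num, node_rank = 2, 1               # two machines, this one is the second
world_size = node_num * gpus_per_node    # 8 processes in total
for local_rank in range(gpus_per_node):
    rank = local_rank + node_rank * gpus_per_node
    print(f"local_rank {local_rank} -> global rank {rank} of {world_size}")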

View File

@@ -15,7 +15,8 @@ base:
 models:
   backbone: 'resnet18'
   channel_ratio: 0.75
-  model_path: "../checkpoints/resnet18_1009/best.pth"
+  model_path: "../checkpoints/resnet18_20250715_scale=0.75_sub/best.pth"
+  # model_path: "../checkpoints/resnet18_1009/best.pth"
 heatmap:
   feature_layer: "layer4"

View File

@@ -15,7 +15,8 @@ base:
 models:
   backbone: 'resnet18'
   channel_ratio: 0.75
-  model_path: "../checkpoints/resnet18_1009/best.pth"
+  # model_path: "../checkpoints/resnet18_1009/best.pth"
+  model_path: "../checkpoints/resnet18_20250715_scale=0.75_sub/best.pth"
 heatmap:
   feature_layer: "layer4"

View File

@@ -13,8 +13,8 @@ base:
 # Model configuration
 models:
   backbone: 'resnet18'
-  channel_ratio: 0.75
-  model_path: "checkpoints/resnet18_1009/best.pth"
+  channel_ratio: 1.0
+  model_path: "checkpoints/resnet18_electornic_20250806/best.pth"
   #resnet18_20250715_scale=0.75_sub
   #resnet18_20250718_scale=0.75_nosub
   half: false  # enable half-precision (fp16) testing
@@ -24,11 +24,11 @@ models:
 data:
   test_batch_size: 128  # test batch size
   num_workers: 32  # number of data-loader workers
-  test_dir: "../data_center/contrast_data/v1/extra"  # validation dataset root directory
+  test_dir: "../data_center/electornic/v1/val"  # validation dataset root directory
   test_group_json: "../data_center/contrast_learning/model_test_data/test/inner_group_pairs.json"
-  test_list: "../data_center/contrast_data/v1/extra_cross_same.txt"
+  test_list: "../data_center/electornic/v1/cross_same.txt"
   group_test: false
-  save_image_joint: true
+  save_image_joint: false
   image_joint_pth: "./joint_images"
 transform:

View File

@@ -10,15 +10,16 @@ base:
   embedding_size: 256  # feature dimension
   pin_memory: true  # enable pin_memory
   distributed: true  # enable distributed training
+  dataset: "./dataset_electornic.txt"  # quantization dataset list
 # Model configuration
 models:
-  backbone: 'resnet18'
+  backbone: 'resnet101'
   channel_ratio: 1.0
-  model_path: "../checkpoints/resnet18_1009/best.pth"
-  onnx_model: "../checkpoints/resnet18_3399_sancheng/best.onnx"
-  rknn_model: "../checkpoints/resnet18_3399_sancheng/best_rknn2.3.2_RK3566.rknn"
+  model_path: "../checkpoints/resnet101_electornic_20250807/best.pth"
+  onnx_model: "../checkpoints/resnet101_electornic_20250807/best.onnx"
+  rknn_model: "../checkpoints/resnet101_electornic_20250807/resnet101_electornic.rknn"
   rknn_batch_size: 1
 # Logging and monitoring

View File

@@ -1,4 +1,4 @@
-from model import (resnet18, resnet34, resnet50, mobilevit_s, MobileNetV3_Small, MobileNetV3_Large, mobilenet_v1,
+from model import (resnet18, resnet34, resnet50, resnet101, mobilevit_s, MobileNetV3_Small, MobileNetV3_Large, mobilenet_v1,
                    PPLCNET_x1_0, PPLCNET_x0_5, PPLCNET_x2_5)
 from timm.models import vit_base_patch16_224 as vit_base_16
 from model.metric import ArcFace, CosFace
@@ -13,9 +13,10 @@ class trainer_tools:
     def get_backbone(self):
         backbone_mapping = {
-            'resnet18': lambda: resnet18(scale=self.conf['models']['channel_ratio']),
-            'resnet34': lambda: resnet34(scale=self.conf['models']['channel_ratio']),
-            'resnet50': lambda: resnet50(scale=self.conf['models']['channel_ratio']),
+            'resnet18': lambda: resnet18(scale=self.conf['models']['channel_ratio'], pretrained=True),
+            'resnet34': lambda: resnet34(scale=self.conf['models']['channel_ratio'], pretrained=True),
+            'resnet50': lambda: resnet50(scale=self.conf['models']['channel_ratio'], pretrained=True),
+            'resnet101': lambda: resnet101(scale=self.conf['models']['channel_ratio'], pretrained=True),
             'mobilevit_s': lambda: mobilevit_s(),
             'mobilenetv3_small': lambda: MobileNetV3_Small(),
             'PPLCNET_x1_0': lambda: PPLCNET_x1_0(),
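
For reference, get_backbone() returns a dict of zero-argument constructors, so callers select by name and then invoke; SimilarAnalysis further down uses exactly this pattern. A minimal usage sketch, assuming conf is one of the YAML configs in this commit:

tr_tools = trainer_tools(conf)
backbone_mapping = tr_tools.get_backbone()
name = conf['models']['backbone']
if name in backbone_mapping:
    model = backbone_mapping[name]()  # the lambda defers construction until selection
else:
    raise ValueError(f"Unsupported backbone: {name}")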

View File

@@ -12,9 +12,9 @@ base:
 # Model configuration
 models:
-  backbone: 'resnet18'
-  channel_ratio: 0.75
-  checkpoints: "../checkpoints/resnet18_20250715_scale=0.75_sub/best.pth"
+  backbone: 'resnet101'
+  channel_ratio: 1.0
+  checkpoints: "../checkpoints/resnet101_electornic_20250807/best.pth"
 # Data configuration
 data:
@@ -22,7 +22,7 @@ data:
   test_batch_size: 128  # validation batch size
   num_workers: 32  # number of data-loader workers
   half: true  # use half-precision data
-  img_dirs_path: "/home/lc/data_center/baseStlib/pic/stlib_base"  # root of base standard-library images
+  img_dirs_path: "/shareData/completed_data/scatter_data/electronic_scale/base/total"  # root of base standard-library images
   # img_dirs_path: "/home/lc/contrast_nettest/data/feature_json"
   xlsx_pth: false  # product filter; default None applies no filtering
@@ -41,8 +41,8 @@ logging:
   checkpoint_interval: 30  # checkpoint save interval (epochs)
 save:
-  json_bin: "../search_library/yunhedian_05-09.json"  # combined json output file
+  json_bin: "../search_library/resnet101_electronic.json"  # combined json output file
   json_path: "/home/lc/data_center/baseStlib/feature_json/stlib_base_resnet18_sub"  # directory for individual json files
   error_barcodes: "error_barcodes.txt"
   barcodes_statistics: "../search_library/barcodes_statistics.txt"
-  create_single_json: true  # whether to save individual json files
+  create_single_json: false  # whether to save individual json files

View File

@@ -4,7 +4,7 @@ from .mobilevit import mobilevit_s
 from .metric import ArcFace, CosFace
 from .loss import FocalLoss
 from .resbam import resnet
-from .resnet_pre import resnet18, resnet34, resnet50, resnet14, CustomResNet18
+from .resnet_pre import resnet18, resnet34, resnet50, resnet101, resnet152, resnet14, CustomResNet18
 from .mobilenet_v2 import mobilenet_v2
 from .mobilenet_v3 import MobileNetV3_Small, MobileNetV3_Large
 # from .mobilenet_v1 import mobilenet_v1

View File

@@ -368,7 +368,7 @@ def resnet18(pretrained=True, progress=True, **kwargs):
                    **kwargs)
-def resnet34(pretrained=False, progress=True, **kwargs):
+def resnet34(pretrained=True, progress=True, **kwargs):
     r"""ResNet-34 model from
     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
@@ -380,7 +380,7 @@ def resnet34(pretrained=False, progress=True, **kwargs):
                    **kwargs)
-def resnet50(pretrained=False, progress=True, **kwargs):
+def resnet50(pretrained=True, progress=True, **kwargs):
     r"""ResNet-50 model from
     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
@@ -392,7 +392,7 @@ def resnet50(pretrained=False, progress=True, **kwargs):
                    **kwargs)
-def resnet101(pretrained=False, progress=True, **kwargs):
+def resnet101(pretrained=True, progress=True, **kwargs):
     r"""ResNet-101 model from
     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

View File

@@ -17,7 +17,7 @@ from configs import trainer_tools
 import yaml
 from datetime import datetime
-with open('./configs/test.yml', 'r') as f:
+with open('../configs/test.yml', 'r') as f:
     conf = yaml.load(f, Loader=yaml.FullLoader)
 # Constants from config

View File

@@ -0,0 +1,23 @@
+../electronic_imgs/0.jpg
+../electronic_imgs/1.jpg
+../electronic_imgs/2.jpg
+../electronic_imgs/3.jpg
+../electronic_imgs/4.jpg
+../electronic_imgs/5.jpg
+../electronic_imgs/6.jpg
+../electronic_imgs/7.jpg
+../electronic_imgs/8.jpg
+../electronic_imgs/9.jpg
+../electronic_imgs/10.jpg
+../electronic_imgs/11.jpg
+../electronic_imgs/12.jpg
+../electronic_imgs/13.jpg
+../electronic_imgs/14.jpg
+../electronic_imgs/15.jpg
+../electronic_imgs/16.jpg
+../electronic_imgs/17.jpg
+../electronic_imgs/18.jpg
+../electronic_imgs/19.jpg
+../electronic_imgs/20.jpg
+../electronic_imgs/21.jpg
+../electronic_imgs/22.jpg
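
This new list is the calibration set consumed by rknn.build() in the conversion hunk further down (dataset=conf['base']['dataset']). A sketch of how such a list could be regenerated, assuming the images live in ../electronic_imgs/ and are named <index>.jpg as above:

import os

img_dir = '../electronic_imgs'  # assumed location, matching the paths above
names = [n for n in os.listdir(img_dir) if n.endswith('.jpg')]
names.sort(key=lambda n: int(os.path.splitext(n)[0]))  # numeric order: 0.jpg, 1.jpg, ...
with open('dataset_electornic.txt', 'w') as f:
    for n in names:
        f.write(f'{img_dir}/{n}\n')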

View File

@@ -203,11 +203,11 @@ class PairGenerator:
 if __name__ == "__main__":
-    original_path = '/home/lc/data_center/contrast_data/v1/extra'
+    original_path = '/home/lc/data_center/electornic/v1/val'
     parent_dir = str(Path(original_path).parent)
     generator = PairGenerator(original_path)
     # Example usage:
     pairs = generator.get_pairs(original_path,
-                                output_txt=os.sep.join([parent_dir, 'extra_cross_same.txt']))  # Individual pairs
+                                output_txt=os.sep.join([parent_dir, 'cross_same.txt']))  # Individual pairs
     # groups = generator.get_group_pairs('val')  # Group pairs

View File

@@ -122,8 +122,9 @@ if __name__ == '__main__':
     # Build model
     print('--> Building model')
-    ret = rknn.build(do_quantization=False,  # True
-                     dataset='./dataset.txt',
+    ret = rknn.build(do_quantization=True,
+                     # dataset='./dataset.txt',
+                     dataset=conf['base']['dataset'],
                      rknn_batch_size=conf['models']['rknn_batch_size'])
     # ret = rknn.build(do_quantization=False, dataset='./dataset.txt')
     if ret != 0:

View File

@@ -237,6 +237,6 @@ def get_histogram(data, label=None):
 if __name__ == '__main__':
-    # picTopic_matrix = picDirSimilarAnalysis()
-    # picTopic_matrix.get_group_similarity_matrix('/home/lc/data_center/image_analysis/pic_pic_similar_maxtrix')
-    read_result_txt()
+    picTopic_matrix = picDirSimilarAnalysis()
+    picTopic_matrix.get_group_similarity_matrix('/home/lc/data_center/image_analysis/pic_pic_similar_maxtrix')
+    # read_result_txt()

View File

@@ -22,7 +22,7 @@ class SimilarAnalysis:
         """Initialize the model and the metric method."""
         tr_tools = trainer_tools(conf)
         backbone_mapping = tr_tools.get_backbone()
+        print('model_path {}'.format(conf['models']['model_path']))
         if conf['models']['backbone'] in backbone_mapping:
             model = backbone_mapping[conf['models']['backbone']]()
         else:

View File

@@ -4,7 +4,7 @@ import logging
 import numpy as np
 from typing import Dict, List, Optional, Tuple
 from tools.dataset import get_transform
-from model import resnet18
+from model import resnet18, resnet34, resnet50, resnet101
 import torch
 from PIL import Image
 import pandas as pd
@@ -50,7 +50,16 @@ class FeatureExtractor:
             raise FileNotFoundError(f"Model weights file not found: {model_path}")
         # Initialize model
-        model = resnet18(scale=self.conf['models']['channel_ratio']).to(self.conf['base']['device'])
+        if conf['models']['backbone'] == 'resnet18':
+            model = resnet18(scale=self.conf['models']['channel_ratio']).to(self.conf['base']['device'])
+        elif conf['models']['backbone'] == 'resnet34':
+            model = resnet34(scale=self.conf['models']['channel_ratio']).to(self.conf['base']['device'])
+        elif conf['models']['backbone'] == 'resnet50':
+            model = resnet50(scale=self.conf['models']['channel_ratio']).to(self.conf['base']['device'])
+        elif conf['models']['backbone'] == 'resnet101':
+            model = resnet101(scale=self.conf['models']['channel_ratio']).to(self.conf['base']['device'])
+        else:
+            print("Unsupported model: {}".format(conf['models']['backbone']))
         # Handle multi-GPU case
         if conf['base']['distributed']:
@@ -168,7 +177,7 @@ class FeatureExtractor:
         # Validate input directory
         if not os.path.isdir(folder):
             raise ValueError(f"Invalid directory: {folder}")
+        i = 0
         # Process each barcode directory
         for root, dirs, files in tqdm(os.walk(folder), desc="Scanning directories"):
             if not dirs:  # Leaf directory (contains images)
@@ -180,14 +189,16 @@ class FeatureExtractor:
                 ori_barcode = basename
                 barcode = basename
                 # Apply filter if provided
+                i += 1
+                print(ori_barcode, i)
                 if filter and ori_barcode not in filter:
                     continue
-                elif len(ori_barcode) > 13 or len(ori_barcode) < 8:
-                    logger.warning(f"Skipping invalid barcode {ori_barcode}")
-                    with open(conf['save']['error_barcodes'], 'a') as f:
-                        f.write(ori_barcode + '\n')
-                    f.close()
-                    continue
+                # elif len(ori_barcode) > 13 or len(ori_barcode) < 8:  # barcode length filter
+                #     logger.warning(f"Skipping invalid barcode {ori_barcode}")
+                #     with open(conf['save']['error_barcodes'], 'a') as f:
+                #         f.write(ori_barcode + '\n')
+                #     f.close()
+                #     continue
                 # Process image files
                 if files:
@@ -299,7 +310,8 @@ class FeatureExtractor:
             dicts['value'] = truncated_imgs_list
             if create_single_json:
                 # json_path = os.path.join("./search_library/v8021_overseas/", str(barcode_list[i]) + '.json')
-                json_path = os.path.join(self.conf['save']['json_path'], str(barcode_list[i]) + '.json')
+                json_path = os.path.join(self.conf['save']['json_path'],
+                                         str(barcode_list[i]) + '.json')
                 with open(json_path, 'w') as json_file:
                     json.dump(dicts, json_file)
             else:
@@ -317,8 +329,10 @@ class FeatureExtractor:
         with open(conf['save']['barcodes_statistics'], 'w', encoding='utf-8') as f:
             for barcode in os.listdir(pth):
                 print("barcode length >> {}".format(len(barcode)))
-                if len(barcode) > 13 or len(barcode) < 8:
-                    continue
+                # if len(barcode) > 13 or len(barcode) < 8:  # barcode length filter
+                #     continue
                 if filter is not None:
                     f.writelines(barcode + '\n')
                     if barcode in filter:
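
The new if/elif chain mirrors trainer_tools.get_backbone(). As a sketch only (not what the commit does), the same selection could be table-driven inside this method:

# Sketch: dictionary-based equivalent of the if/elif chain above.
constructors = {'resnet18': resnet18, 'resnet34': resnet34,
                'resnet50': resnet50, 'resnet101': resnet101}
name = self.conf['models']['backbone']
if name not in constructors:
    raise ValueError(f"Unsupported model: {name}")
model = constructors[name](scale=self.conf['models']['channel_ratio']).to(self.conf['base']['device'])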

View File

@@ -266,11 +266,15 @@ def main():
     if distributed:
         # Distributed training: use mp.spawn to launch one process per local GPU
-        world_size = torch.cuda.device_count()
+        local_size = torch.cuda.device_count()
+        world_size = int(conf['distributed']['node_num']) * local_size
         mp.spawn(
             run_training,
-            args=(world_size, conf),
-            nprocs=world_size,
+            args=(conf['distributed']['node_rank'],
+                  local_size,
+                  world_size,
+                  conf),
+            nprocs=local_size,
             join=True
         )
     else:
@@ -279,11 +283,12 @@ def main():
         run_training_loop(components)
-def run_training(rank, world_size, conf):
+def run_training(local_rank, node_rank, local_size, world_size, conf):
     """Run the actual training; invoked by mp.spawn."""
     # Initialize the distributed environment
+    rank = local_rank + node_rank * local_size
     os.environ['RANK'] = str(rank)
-    os.environ['WORLD_SIZE'] = str(world_size)
+    os.environ['WORLD_SIZE'] = str(local_size)
     os.environ['MASTER_ADDR'] = 'localhost'
     os.environ['MASTER_PORT'] = '12355'
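
In the hunk above, each node derives its processes' global ranks locally, while MASTER_ADDR stays 'localhost' and WORLD_SIZE is set from local_size. For a run that genuinely spans machines, torch.distributed expects WORLD_SIZE to be the total process count across all nodes and MASTER_ADDR to be an address of the rank-0 node reachable from every machine. A minimal initialization sketch under those assumptions (address and port are placeholders):

import os
import torch
import torch.distributed as dist

def init_distributed(local_rank, node_rank, local_size, world_size):
    # Global rank: local GPU index offset by this node's position, as in run_training above.
    rank = local_rank + node_rank * local_size
    os.environ['RANK'] = str(rank)
    os.environ['WORLD_SIZE'] = str(world_size)  # total processes across all nodes
    os.environ['MASTER_ADDR'] = '192.0.2.10'    # placeholder: address of the rank-0 node
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)
    torch.cuda.set_device(local_rank)           # bind this process to its GPU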