Compare commits


5 Commits

Author  SHA1        Message                                               Date
lee     c978787ff8  Multi-machine parallel computing                      2025-08-18 10:14:05 +08:00
lee     99a204ee22  Multi-machine parallel computing                      2025-08-14 10:09:54 +08:00
lee     bc896fc688  Modify the Dataloader to improve training efficiency  2025-08-07 11:00:36 +08:00
lee     27ffb62223  Modify the Dataloader to improve training efficiency  2025-08-07 10:56:32 +08:00
lee     ebba07d1ca  Modify the Dataloader to improve training efficiency  2025-08-07 10:52:42 +08:00
20 changed files with 229 additions and 111 deletions

.gitignore (vendored)
View File

@@ -8,4 +8,5 @@ loss/
 checkpoints/
 search_library/
 quant_imgs/
+electronic_imgs/
 README.md

View File

@@ -3,6 +3,24 @@
   <component name="CopilotChatHistory">
     <option name="conversations">
       <list>
+        <Conversation>
+          <option name="createTime" value="1755228773977" />
+          <option name="id" value="0198abc99e597020bf8aa3ef78bc8bd3" />
+          <option name="title" value="新对话 2025年8月15日 11:32:53" />
+          <option name="updateTime" value="1755228773977" />
+        </Conversation>
+        <Conversation>
+          <option name="createTime" value="1755227620606" />
+          <option name="id" value="0198abb804fe7bf8ab3ac9ecfeae6d3f" />
+          <option name="title" value="新对话 2025年8月15日 11:13:40" />
+          <option name="updateTime" value="1755227620606" />
+        </Conversation>
+        <Conversation>
+          <option name="createTime" value="1755219481041" />
+          <option name="id" value="0198ab3bd1d17216b0dab33158ff294e" />
+          <option name="title" value="新对话 2025年8月15日 08:58:01" />
+          <option name="updateTime" value="1755219481041" />
+        </Conversation>
+        <Conversation>
         <Conversation>
           <option name="createTime" value="1754286137102" />
           <option name="id" value="0198739a1f0e75c38b0579ade7b34050" />

View File

@@ -16,7 +16,7 @@ base:
 # Model configuration
 models:
   backbone: 'resnet18'
-  channel_ratio: 1.0
+  channel_ratio: 0.75

 # Training parameters
 training:
@@ -31,9 +31,9 @@ training:
   weight_decay: 0.0005  # weight decay
   scheduler: "step"  # LR scheduler, options: cosine/cosine_warm/step/None
   num_workers: 32  # number of data-loading workers
-  checkpoints: "./checkpoints/resnet18_electornic_20250806/"  # model save directory
+  checkpoints: "./checkpoints/resnet18_pdd_test/"  # model save directory
   restore: false
-  restore_model: "./checkpoints/resnet18_20250717_scale=0.75_nosub/best.pth"  # model restore path
+  restore_model: "./checkpoints/resnet50_electornic_20250807/best.pth"  # model restore path
   cosine_t_0: 10  # initial cycle length
   cosine_t_mult: 1  # cycle length multiplier
   cosine_eta_min: 0.00001  # minimum learning rate
@@ -62,11 +62,13 @@ transform:
 # Logging and monitoring
 logging:
-  logging_dir: "./logs/resnet18_scale=0.75_nosub_log"  # log directory
+  logging_dir: "./logs/resnet50_electornic_log"  # log directory
   tensorboard: true  # whether to enable TensorBoard
   checkpoint_interval: 30  # checkpoint save interval (epochs)

 # Distributed training (optional)
 distributed:
-  enabled: false  # whether to enable distributed training
+  enabled: true  # whether to enable distributed training
   backend: "nccl"  # distributed backend: nccl/gloo
+  node_rank: 0  # node index
+  node_num: 2  # total number of nodes (usually one node per machine)
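Note: the two new keys feed the multi-machine launch logic added to the training script further down in this compare. A minimal sketch of the arithmetic, with an illustrative GPU count:

```python
import torch

# Sketch only: mirrors the world_size / rank computation in the training script below.
node_rank = 0                                   # distributed.node_rank of this machine
node_num = 2                                    # distributed.node_num (total machines)
local_size = torch.cuda.device_count()          # GPUs on this machine, e.g. 4
world_size = node_num * local_size              # total number of training processes

for local_rank in range(local_size):            # one spawned process per local GPU
    rank = local_rank + node_rank * local_size  # globally unique rank across machines
    print(f"local_rank={local_rank} -> global rank {rank} of {world_size}")
```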

View File

@@ -15,7 +15,8 @@ base:
 models:
   backbone: 'resnet18'
   channel_ratio: 0.75
-  model_path: "../checkpoints/resnet18_1009/best.pth"
+  model_path: "../checkpoints/resnet18_20250715_scale=0.75_sub/best.pth"
+  # model_path: "../checkpoints/resnet18_1009/best.pth"

 heatmap:
   feature_layer: "layer4"
@@ -26,10 +27,10 @@ data:
   train_batch_size: 128  # training batch size
   val_batch_size: 8  # validation batch size
   num_workers: 32  # number of data-loading workers
-  data_dir: "/home/lc/data_center/image_analysis/pic_pic_similar_maxtrix"
+  data_dir: "/home/lc/data_center/image_analysis/pic_pic_similar_maxtrix_new"
   image_joint_pth: "/home/lc/data_center/image_analysis/error_compare_result"
-  total_pkl: "/home/lc/data_center/image_analysis/pic_pic_similar_maxtrix/total.pkl"
-  result_txt: "/home/lc/data_center/image_analysis/pic_pic_similar_maxtrix/result.txt"
+  total_pkl: "/home/lc/data_center/image_analysis/pic_pic_similar_maxtrix_new/total.pkl"
+  result_txt: "/home/lc/data_center/image_analysis/pic_pic_similar_maxtrix_new/result.txt"

 transform:
   img_size: 224  # image size
@@ -45,9 +46,9 @@ logging:
   tensorboard: true  # whether to enable TensorBoard
   checkpoint_interval: 30  # checkpoint save interval (epochs)

-event:
-  oneToOne_max_th: 0.9
-  oneToSn_min_th: 0.6
-  event_save_dir: "/home/lc/works/realtime_yolov10s/online_yolov10s_resnetv11_20250702/yolos_tracking"
-  stdlib_image_path: "/testDataAndLogs/module_test_record/comparison/标准图测试数据/pic/stlib_base"
-  pickle_path: "event.pickle"
+#event:
+#  oneToOne_max_th: 0.9
+#  oneToSn_min_th: 0.6
+#  event_save_dir: "/home/lc/works/realtime_yolov10s/online_yolov10s_resnetv11_20250702/yolos_tracking"
+#  stdlib_image_path: "/testDataAndLogs/module_test_record/comparison/标准图测试数据/pic/stlib_base"
+#  pickle_path: "event.pickle"

View File

@@ -15,7 +15,8 @@ base:
 models:
   backbone: 'resnet18'
   channel_ratio: 0.75
-  model_path: "../checkpoints/resnet18_1009/best.pth"
+  # model_path: "../checkpoints/resnet18_1009/best.pth"
+  model_path: "../checkpoints/resnet18_20250715_scale=0.75_sub/best.pth"

 heatmap:
   feature_layer: "layer4"

View File

@@ -13,8 +13,8 @@ base:
 # Model configuration
 models:
   backbone: 'resnet18'
-  channel_ratio: 0.75
-  model_path: "checkpoints/resnet18_1009/best.pth"
+  channel_ratio: 1.0
+  model_path: "checkpoints/resnet18_electornic_20250806/best.pth"
   #resnet18_20250715_scale=0.75_sub
   #resnet18_20250718_scale=0.75_nosub
   half: false  # whether to run the test in half precision (fp16)
@@ -24,11 +24,11 @@ models:
 data:
   test_batch_size: 128  # batch size
   num_workers: 32  # number of data-loading workers
-  test_dir: "../data_center/contrast_data/v1/extra"  # validation dataset root directory
+  test_dir: "../data_center/electornic/v1/val"  # validation dataset root directory
   test_group_json: "../data_center/contrast_learning/model_test_data/test/inner_group_pairs.json"
-  test_list: "../data_center/contrast_data/v1/extra_cross_same.txt"
+  test_list: "../data_center/electornic/v1/cross_same.txt"
   group_test: false
-  save_image_joint: true
+  save_image_joint: false
   image_joint_pth: "./joint_images"

 transform:

View File

@@ -10,15 +10,16 @@ base:
   embedding_size: 256  # feature dimension
   pin_memory: true  # whether to enable pin_memory
   distributed: true  # whether to enable distributed training
+  dataset: "./dataset_electornic.txt"  # dataset list file

 # Model configuration
 models:
-  backbone: 'resnet18'
+  backbone: 'resnet101'
   channel_ratio: 1.0
-  model_path: "../checkpoints/resnet18_1009/best.pth"
-  onnx_model: "../checkpoints/resnet18_3399_sancheng/best.onnx"
-  rknn_model: "../checkpoints/resnet18_3399_sancheng/best_rknn2.3.2_RK3566.rknn"
+  model_path: "../checkpoints/resnet101_electornic_20250807/best.pth"
+  onnx_model: "../checkpoints/resnet101_electornic_20250807/best.onnx"
+  rknn_model: "../checkpoints/resnet101_electornic_20250807/resnet101_electornic_3588.rknn"
   rknn_batch_size: 1

 # Logging and monitoring
View File

@@ -1,4 +1,4 @@
-from model import (resnet18, resnet34, resnet50, mobilevit_s, MobileNetV3_Small, MobileNetV3_Large, mobilenet_v1,
+from model import (resnet18, resnet34, resnet50, resnet101, mobilevit_s, MobileNetV3_Small, MobileNetV3_Large, mobilenet_v1,
                    PPLCNET_x1_0, PPLCNET_x0_5, PPLCNET_x2_5)
 from timm.models import vit_base_patch16_224 as vit_base_16
 from model.metric import ArcFace, CosFace
@@ -13,9 +13,10 @@ class trainer_tools:
     def get_backbone(self):
         backbone_mapping = {
-            'resnet18': lambda: resnet18(scale=self.conf['models']['channel_ratio']),
-            'resnet34': lambda: resnet34(scale=self.conf['models']['channel_ratio']),
-            'resnet50': lambda: resnet50(scale=self.conf['models']['channel_ratio']),
+            'resnet18': lambda: resnet18(scale=self.conf['models']['channel_ratio'], pretrained=True),
+            'resnet34': lambda: resnet34(scale=self.conf['models']['channel_ratio'], pretrained=True),
+            'resnet50': lambda: resnet50(scale=self.conf['models']['channel_ratio'], pretrained=True),
+            'resnet101': lambda: resnet101(scale=self.conf['models']['channel_ratio'], pretrained=True),
             'mobilevit_s': lambda: mobilevit_s(),
             'mobilenetv3_small': lambda: MobileNetV3_Small(),
             'PPLCNET_x1_0': lambda: PPLCNET_x1_0(),
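For reference, a minimal usage sketch of the extended mapping, mirroring how other callers in this compare consume get_backbone(); the config path is the one used by the test script below:

```python
import yaml
from configs import trainer_tools

with open('../configs/test.yml', 'r') as f:           # same path as in the test script below
    conf = yaml.load(f, Loader=yaml.FullLoader)

backbone_mapping = trainer_tools(conf).get_backbone()
name = conf['models']['backbone']                     # e.g. 'resnet101'
if name in backbone_mapping:
    model = backbone_mapping[name]()                  # the lambda defers construction until selected
else:
    raise ValueError(f"Unsupported backbone: {name}")
```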

View File

@@ -14,7 +14,7 @@ base:
 models:
   backbone: 'resnet18'
   channel_ratio: 0.75
-  checkpoints: "../checkpoints/resnet18_20250715_scale=0.75_sub/best.pth"
+  checkpoints: "../checkpoints/resnet18_20250718_scale=0.75_nosub/best.pth"

 # Data configuration
 data:
@@ -41,8 +41,8 @@ logging:
   checkpoint_interval: 30  # checkpoint save interval (epochs)

 save:
-  json_bin: "../search_library/yunhedian_05-09.json"  # path of the combined json file
-  json_path: "/home/lc/data_center/baseStlib/feature_json/stlib_base_resnet18_sub"  # path for individual json files
+  json_bin: "../search_library/resnet101_electronic.json"  # path of the combined json file
+  json_path: "/home/lc/data_center/baseStlib/feature_json/stlib_base_resnet18_sub_1k_合并"  # path for individual json files
   error_barcodes: "error_barcodes.txt"
   barcodes_statistics: "../search_library/barcodes_statistics.txt"
   create_single_json: true  # whether to save individual json files

View File

@@ -4,7 +4,7 @@ from .mobilevit import mobilevit_s
 from .metric import ArcFace, CosFace
 from .loss import FocalLoss
 from .resbam import resnet
-from .resnet_pre import resnet18, resnet34, resnet50, resnet14, CustomResNet18
+from .resnet_pre import resnet18, resnet34, resnet50, resnet101, resnet152, resnet14, CustomResNet18
 from .mobilenet_v2 import mobilenet_v2
 from .mobilenet_v3 import MobileNetV3_Small, MobileNetV3_Large
 # from .mobilenet_v1 import mobilenet_v1

View File

@@ -368,7 +368,7 @@ def resnet18(pretrained=True, progress=True, **kwargs):
                    **kwargs)

-def resnet34(pretrained=False, progress=True, **kwargs):
+def resnet34(pretrained=True, progress=True, **kwargs):
     r"""ResNet-34 model from
     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
@@ -380,7 +380,7 @@ def resnet34(pretrained=False, progress=True, **kwargs):
                    **kwargs)

-def resnet50(pretrained=False, progress=True, **kwargs):
+def resnet50(pretrained=True, progress=True, **kwargs):
     r"""ResNet-50 model from
     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
@@ -392,7 +392,7 @@ def resnet50(pretrained=False, progress=True, **kwargs):
                    **kwargs)

-def resnet101(pretrained=False, progress=True, **kwargs):
+def resnet101(pretrained=True, progress=True, **kwargs):
     r"""ResNet-101 model from
     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

View File

@@ -17,7 +17,7 @@ from configs import trainer_tools
 import yaml
 from datetime import datetime

-with open('./configs/test.yml', 'r') as f:
+with open('../configs/test.yml', 'r') as f:
     conf = yaml.load(f, Loader=yaml.FullLoader)

 # Constants from config

View File

@@ -5,12 +5,14 @@ import torchvision.transforms as T
 # from config import config as conf
 import torch

 def pad_to_square(img):
     w, h = img.size
     max_wh = max(w, h)
     padding = [(max_wh - w) // 2, (max_wh - h) // 2, (max_wh - w) // 2, (max_wh - h) // 2]  # (left, top, right, bottom)
     return F.pad(img, padding, fill=0, padding_mode='constant')

 def get_transform(cfg):
     train_transform = T.Compose([
         T.Lambda(pad_to_square),  # pad to square
@@ -32,7 +34,8 @@ def get_transform(cfg):
     ])
     return train_transform, test_transform

-def load_data(training=True, cfg=None):
+def load_data(training=True, cfg=None, return_dataset=False):
     train_transform, test_transform = get_transform(cfg)
     if training:
         dataroot = cfg['data']['data_train_dir']
@@ -47,14 +50,49 @@ def load_data(training=True, cfg=None):
     data = ImageFolder(dataroot, transform=transform)
     class_num = len(data.classes)
-    loader = DataLoader(data,
-                        batch_size=batch_size,
-                        shuffle=True,
-                        pin_memory=cfg['base']['pin_memory'],
-                        num_workers=cfg['data']['num_workers'],
-                        drop_last=True)
-    return loader, class_num
+    if return_dataset:
+        return data, class_num
+    else:
+        loader = DataLoader(data,
+                            batch_size=batch_size,
+                            shuffle=True if training else False,
+                            pin_memory=cfg['base']['pin_memory'],
+                            num_workers=cfg['data']['num_workers'],
+                            drop_last=True)
+        return loader, class_num
+
+
+class MultiEpochsDataLoader(torch.utils.data.DataLoader):
+    """
+    MultiEpochsDataLoader
+    Improves data-loading efficiency by reusing worker processes instead of
+    restarting them at the start of every epoch.
+    """
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._DataLoader__initialized = False
+        self.batch_sampler = _RepeatSampler(self.batch_sampler)
+        self._DataLoader__initialized = True
+        self.iterator = super().__iter__()
+
+    def __len__(self):
+        return len(self.batch_sampler.sampler)
+
+    def __iter__(self):
+        for i in range(len(self)):
+            yield next(self.iterator)
+
+
+class _RepeatSampler(object):
+    """
+    Repeating sampler: avoids re-creating the sampler iterator at every epoch.
+    """
+    def __init__(self, sampler):
+        self.sampler = sampler
+
+    def __iter__(self):
+        while True:
+            yield from iter(self.sampler)
+
+
 # def load_gift_data(action):
 #     train_data = ImageFolder(conf.train_gift_root, transform.yml=conf.train_transform)
 #     train_dataset = DataLoader(train_data, batch_size=conf.train_gift_batchsize, shuffle=True,
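A minimal usage sketch of the two additions, assuming `conf` is already loaded from one of the YAML configs above; it mirrors the non-distributed path in the training script further down, where worker processes stay alive across epochs instead of being respawned at every epoch boundary:

```python
from tools.dataset import load_data, MultiEpochsDataLoader

train_dataset, class_num = load_data(training=True, cfg=conf, return_dataset=True)
train_loader = MultiEpochsDataLoader(train_dataset,
                                     batch_size=conf['data']['train_batch_size'],
                                     shuffle=True,
                                     num_workers=conf['data']['num_workers'],
                                     pin_memory=conf['base']['pin_memory'],
                                     drop_last=True)

for epoch in range(10):                  # illustrative epoch count
    for images, labels in train_loader:  # workers are reused, not restarted, each epoch
        pass                             # forward/backward as usual
```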

View File

@@ -0,0 +1,23 @@
+../electronic_imgs/0.jpg
+../electronic_imgs/1.jpg
+../electronic_imgs/2.jpg
+../electronic_imgs/3.jpg
+../electronic_imgs/4.jpg
+../electronic_imgs/5.jpg
+../electronic_imgs/6.jpg
+../electronic_imgs/7.jpg
+../electronic_imgs/8.jpg
+../electronic_imgs/9.jpg
+../electronic_imgs/10.jpg
+../electronic_imgs/11.jpg
+../electronic_imgs/12.jpg
+../electronic_imgs/13.jpg
+../electronic_imgs/14.jpg
+../electronic_imgs/15.jpg
+../electronic_imgs/16.jpg
+../electronic_imgs/17.jpg
+../electronic_imgs/18.jpg
+../electronic_imgs/19.jpg
+../electronic_imgs/20.jpg
+../electronic_imgs/21.jpg
+../electronic_imgs/22.jpg

View File

@@ -203,11 +203,11 @@ class PairGenerator:
 if __name__ == "__main__":
-    original_path = '/home/lc/data_center/contrast_data/v1/extra'
+    original_path = '/home/lc/data_center/electornic/v1/val'
     parent_dir = str(Path(original_path).parent)

     generator = PairGenerator(original_path)
     # Example usage:
     pairs = generator.get_pairs(original_path,
-                                output_txt=os.sep.join([parent_dir, 'extra_cross_same.txt']))  # Individual pairs
+                                output_txt=os.sep.join([parent_dir, 'cross_same.txt']))  # Individual pairs
     # groups = generator.get_group_pairs('val')  # Group pairs

View File

@@ -96,7 +96,7 @@ if __name__ == '__main__':
     rknn.config(
         mean_values=[[127.5, 127.5, 127.5]],
         std_values=[[127.5, 127.5, 127.5]],
-        target_platform='rk3566',
+        target_platform='rk3588',
         model_pruning=False,
         compress_weight=False,
         single_core_mode=True,
@@ -122,8 +122,9 @@ if __name__ == '__main__':
     # Build model
     print('--> Building model')
-    ret = rknn.build(do_quantization=False,  # True
-                     dataset='./dataset.txt',
+    ret = rknn.build(do_quantization=True,  # True
+                     # dataset='./dataset.txt',
+                     dataset=conf['base']['dataset'],
                      rknn_batch_size=conf['models']['rknn_batch_size'])
     # ret = rknn.build(do_quantization=False, dataset='./dataset.txt')
     if ret != 0:
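With quantization enabled, the dataset argument points rknn.build() at a plain-text list of calibration image paths (conf['base']['dataset'], i.e. the dataset_electornic.txt added above). A hypothetical helper, not part of this change set, that writes such a list:

```python
from pathlib import Path

def write_quant_dataset(img_dir: str, out_txt: str, limit: int = 100) -> None:
    """Write one image path per line, matching the format of dataset_electornic.txt."""
    paths = sorted(Path(img_dir).glob('*.jpg'))[:limit]
    Path(out_txt).write_text('\n'.join(str(p) for p in paths) + '\n')

# Example (paths are illustrative):
# write_quant_dataset('../electronic_imgs', './dataset_electornic.txt')
```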

View File

@@ -237,6 +237,6 @@ def get_histogram(data, label=None):
 if __name__ == '__main__':
-    # picTopic_matrix = picDirSimilarAnalysis()
-    # picTopic_matrix.get_group_similarity_matrix('/home/lc/data_center/image_analysis/pic_pic_similar_maxtrix')
-    read_result_txt()
+    picTopic_matrix = picDirSimilarAnalysis()
+    picTopic_matrix.get_group_similarity_matrix('/home/lc/data_center/image_analysis/pic_pic_similar_maxtrix_new')
+    # read_result_txt()

View File

@@ -22,7 +22,7 @@ class SimilarAnalysis:
         """Initialize the model and the metric method."""
         tr_tools = trainer_tools(conf)
         backbone_mapping = tr_tools.get_backbone()
+        print('model_path {}'.format(conf['models']['model_path']))
         if conf['models']['backbone'] in backbone_mapping:
             model = backbone_mapping[conf['models']['backbone']]()
         else:

View File

@@ -4,7 +4,7 @@ import logging
 import numpy as np
 from typing import Dict, List, Optional, Tuple
 from tools.dataset import get_transform
-from model import resnet18
+from model import resnet18, resnet34, resnet50, resnet101
 import torch
 from PIL import Image
 import pandas as pd
@@ -50,7 +50,16 @@
             raise FileNotFoundError(f"Model weights file not found: {model_path}")

         # Initialize model
-        model = resnet18(scale=self.conf['models']['channel_ratio']).to(self.conf['base']['device'])
+        if conf['models']['backbone'] == 'resnet18':
+            model = resnet18(scale=self.conf['models']['channel_ratio']).to(self.conf['base']['device'])
+        elif conf['models']['backbone'] == 'resnet34':
+            model = resnet34(scale=self.conf['models']['channel_ratio']).to(self.conf['base']['device'])
+        elif conf['models']['backbone'] == 'resnet50':
+            model = resnet50(scale=self.conf['models']['channel_ratio']).to(self.conf['base']['device'])
+        elif conf['models']['backbone'] == 'resnet101':
+            model = resnet101(scale=self.conf['models']['channel_ratio']).to(self.conf['base']['device'])
+        else:
+            print("不支持的模型: {}".format(conf['models']['backbone']))

         # Handle multi-GPU case
         if conf['base']['distributed']:
@@ -168,7 +177,7 @@
         # Validate input directory
         if not os.path.isdir(folder):
             raise ValueError(f"Invalid directory: {folder}")
+        i = 0
         # Process each barcode directory
         for root, dirs, files in tqdm(os.walk(folder), desc="Scanning directories"):
             if not dirs:  # Leaf directory (contains images)
@@ -180,14 +189,16 @@
                 ori_barcode = basename
                 barcode = basename

                 # Apply filter if provided
+                i += 1
+                print(ori_barcode, i)
                 if filter and ori_barcode not in filter:
                     continue
-                elif len(ori_barcode) > 13 or len(ori_barcode) < 8:
-                    logger.warning(f"Skipping invalid barcode {ori_barcode}")
-                    with open(conf['save']['error_barcodes'], 'a') as f:
-                        f.write(ori_barcode + '\n')
-                    f.close()
-                    continue
+                # elif len(ori_barcode) > 13 or len(ori_barcode) < 8:  # filter by barcode length
+                #     logger.warning(f"Skipping invalid barcode {ori_barcode}")
+                #     with open(conf['save']['error_barcodes'], 'a') as f:
+                #         f.write(ori_barcode + '\n')
+                #     f.close()
+                #     continue

                 # Process image files
                 if files:
@@ -299,7 +310,8 @@
                     dicts['value'] = truncated_imgs_list
                     if create_single_json:
                         # json_path = os.path.join("./search_library/v8021_overseas/", str(barcode_list[i]) + '.json')
-                        json_path = os.path.join(self.conf['save']['json_path'], str(barcode_list[i]) + '.json')
+                        json_path = os.path.join(self.conf['save']['json_path'],
+                                                 str(barcode_list[i]) + '.json')
                         with open(json_path, 'w') as json_file:
                             json.dump(dicts, json_file)
                     else:
@@ -317,8 +329,10 @@
     with open(conf['save']['barcodes_statistics'], 'w', encoding='utf-8') as f:
         for barcode in os.listdir(pth):
             print("barcode length >> {}".format(len(barcode)))
-            if len(barcode) > 13 or len(barcode) < 8:
-                continue
+            # if len(barcode) > 13 or len(barcode) < 8:  # filter by barcode length
+            #     continue
             if filter is not None:
                 f.writelines(barcode + '\n')
                 if barcode in filter:

View File

@@ -10,7 +10,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.data.distributed import DistributedSampler
 from model.loss import FocalLoss
-from tools.dataset import load_data
+from tools.dataset import load_data, MultiEpochsDataLoader
 import matplotlib.pyplot as plt
 from configs import trainer_tools
 import yaml
@@ -146,9 +146,21 @@ def initialize_training_components(distributed=False):
     # For non-distributed training, create all components directly
     if not distributed:
         # Data loading
-        train_dataloader, class_num = load_data(training=True, cfg=conf)
-        val_dataloader, _ = load_data(training=False, cfg=conf)
+        train_dataloader, class_num = load_data(training=True, cfg=conf, return_dataset=True)
+        val_dataloader, _ = load_data(training=False, cfg=conf, return_dataset=True)
+        train_dataloader = MultiEpochsDataLoader(train_dataloader,
+                                                 batch_size=conf['data']['train_batch_size'],
+                                                 shuffle=True,
+                                                 num_workers=conf['data']['num_workers'],
+                                                 pin_memory=conf['base']['pin_memory'],
+                                                 drop_last=True)
+        val_dataloader = MultiEpochsDataLoader(val_dataloader,
+                                               batch_size=conf['data']['val_batch_size'],
+                                               shuffle=False,
+                                               num_workers=conf['data']['num_workers'],
+                                               pin_memory=conf['base']['pin_memory'],
+                                               drop_last=False)

         # Initialize the model and metric
         model, metric = initialize_model_and_metric(conf, class_num)
         device = conf['base']['device']
@@ -254,11 +266,15 @@ def main():
     if distributed:
         # Distributed training: use mp.spawn to launch multiple processes
-        world_size = torch.cuda.device_count()
+        local_size = torch.cuda.device_count()
+        world_size = int(conf['distributed']['node_num']) * local_size
         mp.spawn(
             run_training,
-            args=(world_size, conf),
-            nprocs=world_size,
+            args=(conf['distributed']['node_rank'],
+                  local_size,
+                  world_size,
+                  conf),
+            nprocs=local_size,
             join=True
         )
     else:
@@ -267,21 +283,22 @@ def main():
         run_training_loop(components)

-def run_training(rank, world_size, conf):
+def run_training(local_rank, node_rank, local_size, world_size, conf):
     """The function that actually runs training; called by mp.spawn."""
     # Initialize the distributed environment
+    rank = local_rank + node_rank * local_size
     os.environ['RANK'] = str(rank)
     os.environ['WORLD_SIZE'] = str(world_size)
     os.environ['MASTER_ADDR'] = 'localhost'
     os.environ['MASTER_PORT'] = '12355'
     dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)
-    torch.cuda.set_device(rank)
-    device = torch.device('cuda', rank)
+    torch.cuda.set_device(local_rank)
+    device = torch.device('cuda', local_rank)

-    # Create the data loaders, model and other components (distributed case)
-    train_dataloader, class_num = load_data(training=True, cfg=conf)
-    val_dataloader, _ = load_data(training=False, cfg=conf)
+    # Fetch the datasets instead of DataLoaders
+    train_dataset, class_num = load_data(training=True, cfg=conf, return_dataset=True)
+    val_dataset, _ = load_data(training=False, cfg=conf, return_dataset=True)

     # Initialize the model and metric
     model, metric = initialize_model_and_metric(conf, class_num)
@@ -289,8 +306,8 @@ def run_training(rank, world_size, conf):
     metric = metric.to(device)

     # Wrap as DistributedDataParallel models
-    model = DDP(model, device_ids=[rank], output_device=rank)
-    metric = DDP(metric, device_ids=[rank], output_device=rank)
+    model = DDP(model, device_ids=[local_rank], output_device=local_rank)
+    metric = DDP(metric, device_ids=[local_rank], output_device=local_rank)

     # Set up the loss function, optimizer and scheduler
     criterion = setup_loss_function(conf)
@@ -303,27 +320,27 @@ def run_training(rank, world_size, conf):
     # GradScaler for mixed precision
     scaler = torch.cuda.amp.GradScaler()

-    # Create distributed data loading
-    train_sampler = DistributedSampler(train_dataloader.dataset, shuffle=True)
-    val_sampler = DistributedSampler(val_dataloader.dataset, shuffle=False)
+    # Create distributed samplers
+    train_sampler = DistributedSampler(train_dataset, shuffle=True)
+    val_sampler = DistributedSampler(val_dataset, shuffle=False)

-    # Re-create data loaders suited to distributed training
-    train_dataloader = torch.utils.data.DataLoader(
-        train_dataloader.dataset,
-        batch_size=train_dataloader.batch_size,
+    # Build the distributed data loaders with MultiEpochsDataLoader
+    train_dataloader = MultiEpochsDataLoader(
+        train_dataset,
+        batch_size=conf['data']['train_batch_size'],
         sampler=train_sampler,
-        num_workers=train_dataloader.num_workers,
-        pin_memory=train_dataloader.pin_memory,
-        drop_last=train_dataloader.drop_last
+        num_workers=conf['data']['num_workers'],
+        pin_memory=conf['base']['pin_memory'],
+        drop_last=True
     )
-    val_dataloader = torch.utils.data.DataLoader(
-        val_dataloader.dataset,
-        batch_size=val_dataloader.batch_size,
+    val_dataloader = MultiEpochsDataLoader(
+        val_dataset,
+        batch_size=conf['data']['val_batch_size'],
         sampler=val_sampler,
-        num_workers=val_dataloader.num_workers,
-        pin_memory=val_dataloader.pin_memory,
-        drop_last=val_dataloader.drop_last
+        num_workers=conf['data']['num_workers'],
+        pin_memory=conf['base']['pin_memory'],
+        drop_last=False
     )

     # Build the component dictionary
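A note on running this across the two nodes declared in the config: each machine keeps node_num: 2 but sets its own node_rank (0 on the first, 1 on the second) before launching the training script. The hunk above still hardcodes MASTER_ADDR='localhost', which only lets processes on a single machine rendezvous; a hedged sketch of what those two assignments would need to become for a genuine two-machine run (the address is illustrative):

```python
import os

# Sketch only, not part of this change set: every node must point at the
# machine that hosts global rank 0 for torch.distributed to rendezvous.
os.environ['MASTER_ADDR'] = '192.168.1.10'   # illustrative IP of the node_rank-0 machine
os.environ['MASTER_PORT'] = '12355'          # same port as used in run_training above

# With node_rank set per machine and node_num: 2 on both, run_training gives
# each process rank = local_rank + node_rank * local_size, so the global ranks
# 0..world_size-1 are disjoint across the two machines.
```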