From ebba07d1ca2101900820513d5dac939e2441d94b Mon Sep 17 00:00:00 2001
From: lee <770918727@qq.com>
Date: Thu, 7 Aug 2025 10:52:42 +0800
Subject: [PATCH] Modify the Dataloader to improve training efficiency
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 configs/compare.yml |  6 ++--
 tools/dataset.py    | 54 +++++++++++++++++++++++-----
 train_compare.py    | 86 ++++++++++++++++++++++++++-------------------
 3 files changed, 98 insertions(+), 48 deletions(-)

diff --git a/configs/compare.yml b/configs/compare.yml
index c547a26..f8a0caf 100644
--- a/configs/compare.yml
+++ b/configs/compare.yml
@@ -15,7 +15,7 @@ base:
 
 # 模型配置
 models:
-  backbone: 'resnet18'
+  backbone: 'resnet50'
   channel_ratio: 1.0
 
 # 训练参数
@@ -31,7 +31,7 @@ training:
   weight_decay: 0.0005  # 权重衰减
   scheduler: "step"  # 学习率调度器(可选:cosine/cosine_warm/step/None)
   num_workers: 32  # 数据加载线程数
-  checkpoints: "./checkpoints/resnet18_electornic_20250806/"  # 模型保存目录
+  checkpoints: "./checkpoints/resnet50_electornic_20250807/"  # 模型保存目录
   restore: false
   restore_model: "./checkpoints/resnet18_20250717_scale=0.75_nosub/best.pth"  # 模型恢复路径
   cosine_t_0: 10  # 初始周期长度
@@ -62,7 +62,7 @@ transform:
 
 # 日志与监控
 logging:
-  logging_dir: "./logs/resnet18_scale=0.75_nosub_log"  # 日志保存目录
+  logging_dir: "./logs/resnet50_electornic_log"  # 日志保存目录
   tensorboard: true  # 是否启用TensorBoard
   checkpoint_interval: 30  # 检查点保存间隔(epoch)
 
diff --git a/tools/dataset.py b/tools/dataset.py
index 4060886..e11a002 100644
--- a/tools/dataset.py
+++ b/tools/dataset.py
@@ -5,12 +5,14 @@ import torchvision.transforms as T
 # from config import config as conf
 import torch
 
+
 def pad_to_square(img):
     w, h = img.size
     max_wh = max(w, h)
     padding = [(max_wh - w) // 2, (max_wh - h) // 2, (max_wh - w) // 2, (max_wh - h) // 2]  # (left, top, right, bottom)
     return F.pad(img, padding, fill=0, padding_mode='constant')
 
+
 def get_transform(cfg):
     train_transform = T.Compose([
         T.Lambda(pad_to_square),  # 补边
@@ -32,7 +34,8 @@ def get_transform(cfg):
     ])
     return train_transform, test_transform
 
-def load_data(training=True, cfg=None):
+
+def load_data(training=True, cfg=None, return_dataset=False):
     train_transform, test_transform = get_transform(cfg)
     if training:
         dataroot = cfg['data']['data_train_dir']
@@ -47,14 +50,49 @@
 
     data = ImageFolder(dataroot, transform=transform)
     class_num = len(data.classes)
-    loader = DataLoader(data,
-                        batch_size=batch_size,
-                        shuffle=True,
-                        pin_memory=cfg['base']['pin_memory'],
-                        num_workers=cfg['data']['num_workers'],
-                        drop_last=True)
-    return loader, class_num
+    if return_dataset:
+        return data, class_num
+    else:
+        loader = DataLoader(data,
+                            batch_size=batch_size,
+                            shuffle=True if training else False,
+                            pin_memory=cfg['base']['pin_memory'],
+                            num_workers=cfg['data']['num_workers'],
+                            drop_last=True)
+        return loader, class_num
+
+
+class MultiEpochsDataLoader(torch.utils.data.DataLoader):
+    """
+    MultiEpochsDataLoader 类
+    通过重用工作进程来提高数据加载效率,避免每个epoch重新启动工作进程
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._DataLoader__initialized = False
+        self.batch_sampler = _RepeatSampler(self.batch_sampler)
+        self._DataLoader__initialized = True
+        self.iterator = super().__iter__()
+
+    def __len__(self):
+        return len(self.batch_sampler.sampler)
+
+    def __iter__(self):
+        for i in range(len(self)):
+            yield next(self.iterator)
+
+
+class _RepeatSampler(object):
+    """
+    重复采样器,避免每个epoch重新创建迭代器
+    """
+
+    def __init__(self, sampler):
+        self.sampler = sampler
+
+    def __iter__(self):
+        while True:
+            yield from iter(self.sampler)
 
 # def load_gift_data(action):
 #     train_data = ImageFolder(conf.train_gift_root, transform.yml=conf.train_transform)
 #     train_dataset = DataLoader(train_data, batch_size=conf.train_gift_batchsize, shuffle=True,
diff --git a/train_compare.py b/train_compare.py
index e833207..b1f8208 100644
--- a/train_compare.py
+++ b/train_compare.py
@@ -10,7 +10,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.data.distributed import DistributedSampler
 
 from model.loss import FocalLoss
-from tools.dataset import load_data
+from tools.dataset import load_data, MultiEpochsDataLoader
 import matplotlib.pyplot as plt
 from configs import trainer_tools
 import yaml
@@ -52,7 +52,7 @@ def setup_optimizer_and_scheduler(conf, model, metric):
         scheduler_mapping = tr_tools.get_scheduler(optimizer)
         scheduler = scheduler_mapping[conf['training']['scheduler']]()
         print('使用{}优化器 使用{}调度器'.format(conf['training']['optimizer'],
-                                      conf['training']['scheduler']))
+                                                 conf['training']['scheduler']))
         return optimizer, scheduler
     else:
         raise ValueError('不支持的优化器类型: {}'.format(conf['training']['optimizer']))
@@ -146,9 +146,21 @@ def initialize_training_components(distributed=False):
     # 如果是非分布式训练,直接创建所有组件
     if not distributed:
         # 数据加载
-        train_dataloader, class_num = load_data(training=True, cfg=conf)
-        val_dataloader, _ = load_data(training=False, cfg=conf)
-
+        train_dataloader, class_num = load_data(training=True, cfg=conf, return_dataset=True)
+        val_dataloader, _ = load_data(training=False, cfg=conf, return_dataset=True)
+
+        train_dataloader = MultiEpochsDataLoader(train_dataloader,
+                                                 batch_size=conf['data']['train_batch_size'],
+                                                 shuffle=True,
+                                                 num_workers=conf['data']['num_workers'],
+                                                 pin_memory=conf['base']['pin_memory'],
+                                                 drop_last=True)
+        val_dataloader = MultiEpochsDataLoader(val_dataloader,
+                                               batch_size=conf['data']['val_batch_size'],
+                                               shuffle=False,
+                                               num_workers=conf['data']['num_workers'],
+                                               pin_memory=conf['base']['pin_memory'],
+                                               drop_last=False)
         # 初始化模型和度量
         model, metric = initialize_model_and_metric(conf, class_num)
         device = conf['base']['device']
@@ -248,10 +260,10 @@ def main():
     """主函数入口"""
     # 加载配置
     conf = load_configuration()
-    
+
     # 检查是否启用分布式训练
     distributed = conf['base']['distributed']
-    
+
     if distributed:
         # 分布式训练:使用mp.spawn启动多个进程
         world_size = torch.cuda.device_count()
@@ -274,56 +286,56 @@ def run_training(rank, world_size, conf):
     os.environ['WORLD_SIZE'] = str(world_size)
     os.environ['MASTER_ADDR'] = 'localhost'
     os.environ['MASTER_PORT'] = '12355'
-    
+
     dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)
     torch.cuda.set_device(rank)
     device = torch.device('cuda', rank)
-    
-    # 创建数据加载器和模型等组件(分布式情况下)
-    train_dataloader, class_num = load_data(training=True, cfg=conf)
-    val_dataloader, _ = load_data(training=False, cfg=conf)
-    
+
+    # 获取数据集而不是DataLoader
+    train_dataset, class_num = load_data(training=True, cfg=conf, return_dataset=True)
+    val_dataset, _ = load_data(training=False, cfg=conf, return_dataset=True)
+
     # 初始化模型和度量
     model, metric = initialize_model_and_metric(conf, class_num)
     model = model.to(device)
     metric = metric.to(device)
-    
+
     # 包装为DistributedDataParallel模型
     model = DDP(model, device_ids=[rank], output_device=rank)
     metric = DDP(metric, device_ids=[rank], output_device=rank)
-    
+
     # 设置损失函数、优化器和调度器
     criterion = setup_loss_function(conf)
     optimizer, scheduler = setup_optimizer_and_scheduler(conf, model, metric)
-    
+
     # 检查点目录
     checkpoints = conf['training']['checkpoints']
     os.makedirs(checkpoints, exist_ok=True)
-    
+
     # GradScaler for mixed precision
     scaler = torch.cuda.amp.GradScaler()
-    
-    # 创建分布式数据加载器
-    train_sampler = DistributedSampler(train_dataloader.dataset, shuffle=True)
-    val_sampler = DistributedSampler(val_dataloader.dataset, shuffle=False)
-    
-    # 重新创建适合分布式训练的数据加载器
-    train_dataloader = torch.utils.data.DataLoader(
-        train_dataloader.dataset,
-        batch_size=train_dataloader.batch_size,
+
+    # 创建分布式采样器
+    train_sampler = DistributedSampler(train_dataset, shuffle=True)
+    val_sampler = DistributedSampler(val_dataset, shuffle=False)
+
+    # 使用 MultiEpochsDataLoader 创建分布式数据加载器
+    train_dataloader = MultiEpochsDataLoader(
+        train_dataset,
+        batch_size=conf['data']['train_batch_size'],
         sampler=train_sampler,
-        num_workers=train_dataloader.num_workers,
-        pin_memory=train_dataloader.pin_memory,
-        drop_last=train_dataloader.drop_last
+        num_workers=conf['data']['num_workers'],
+        pin_memory=conf['base']['pin_memory'],
+        drop_last=True
     )
-    
-    val_dataloader = torch.utils.data.DataLoader(
-        val_dataloader.dataset,
-        batch_size=val_dataloader.batch_size,
+
+    val_dataloader = MultiEpochsDataLoader(
+        val_dataset,
+        batch_size=conf['data']['val_batch_size'],
         sampler=val_sampler,
-        num_workers=val_dataloader.num_workers,
-        pin_memory=val_dataloader.pin_memory,
-        drop_last=val_dataloader.drop_last
+        num_workers=conf['data']['num_workers'],
+        pin_memory=conf['base']['pin_memory'],
+        drop_last=False
     )
 
     # 构建组件字典
@@ -341,7 +353,7 @@ def run_training(rank, world_size, conf):
         'device': device,
         'distributed': True  # 因为是在mp.spawn中运行
     }
-    
+
     # 运行训练循环
     run_training_loop(components)
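
Note (not part of the patch): the sketch below is a minimal, self-contained way to exercise the new MultiEpochsDataLoader from tools/dataset.py. The toy TensorDataset, batch size, worker count, and epoch count are illustrative assumptions, not values from this repository's config. Because __init__ creates the worker iterator once and _RepeatSampler keeps re-iterating the underlying batch sampler, the worker processes survive across epochs instead of being torn down and respawned.

    # Standalone usage sketch; the dataset and hyperparameters are made up for illustration.
    import torch
    from torch.utils.data import TensorDataset

    from tools.dataset import MultiEpochsDataLoader


    def main():
        # 40 dummy samples with one feature each, plus integer labels.
        features = torch.arange(40, dtype=torch.float32).unsqueeze(1)
        labels = torch.zeros(40, dtype=torch.long)
        dataset = TensorDataset(features, labels)

        # Worker processes start once here (inside __init__) and are reused across epochs;
        # len(loader) still reports batches per epoch (drop_last applies).
        loader = MultiEpochsDataLoader(dataset, batch_size=16, shuffle=True,
                                       num_workers=2, pin_memory=False, drop_last=True)

        for epoch in range(3):
            for batch_features, batch_labels in loader:
                pass  # forward/backward pass would go here
            print('epoch {}: {} batches'.format(epoch, len(loader)))


    if __name__ == '__main__':
        main()

For comparison, recent PyTorch releases also expose persistent_workers=True on the standard DataLoader, which likewise keeps workers alive between epochs; the class added in this patch goes a step further by reusing a single repeating iterator.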