From 6640f2bc5e8cbe95a8e26b4ce903c2a6d21befca Mon Sep 17 00:00:00 2001 From: lee <770918727@qq.com> Date: Thu, 3 Jul 2025 15:16:58 +0800 Subject: [PATCH] =?UTF-8?q?=E8=AE=AD=E7=BB=83=E4=BB=A3=E7=A0=81=E4=BC=98?= =?UTF-8?q?=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- train_compare.py | 230 +++++++++++++++++++++++++++++++---------------- 1 file changed, 155 insertions(+), 75 deletions(-) diff --git a/train_compare.py b/train_compare.py index 5a413ab..d3205ea 100644 --- a/train_compare.py +++ b/train_compare.py @@ -122,84 +122,66 @@ def log_training_info(log_path, log_info): f.write(log_info + '\n') -def initialize_training_components(): +def initialize_training_components(distributed=False): """初始化所有训练所需组件""" # 加载配置 conf = load_configuration() - # 初始化分布式训练 - distributed = conf['base']['distributed'] - if distributed: - dist.init_process_group(backend='nccl') - local_rank = int(os.environ["LOCAL_RANK"]) - torch.cuda.set_device(local_rank) - device = torch.device('cuda', local_rank) - else: - device = conf['base']['device'] - - # 数据加载 - train_dataloader, class_num = load_data(training=True, cfg=conf) - val_dataloader, _ = load_data(training=False, cfg=conf) - - # 如果使用分布式,需要为每个进程创建单独的数据加载器 - if distributed: - train_sampler = DistributedSampler(train_dataloader.dataset, shuffle=True) - val_sampler = DistributedSampler(val_dataloader.dataset, shuffle=False) - - # 重新创建适合分布式训练的数据加载器 - train_dataloader = torch.utils.data.DataLoader( - train_dataloader.dataset, - batch_size=train_dataloader.batch_size, - sampler=train_sampler, - num_workers=train_dataloader.num_workers, - pin_memory=train_dataloader.pin_memory, - drop_last=train_dataloader.drop_last - ) - - val_dataloader = torch.utils.data.DataLoader( - val_dataloader.dataset, - batch_size=val_dataloader.batch_size, - sampler=val_sampler, - num_workers=val_dataloader.num_workers, - pin_memory=val_dataloader.pin_memory, - drop_last=val_dataloader.drop_last - ) - - # 初始化模型和度量 - model, metric = initialize_model_and_metric(conf, class_num) - model = model.to(device) - metric = metric.to(device) - - if distributed: - model = DDP(model, device_ids=[local_rank], output_device=local_rank) - metric = DDP(metric, device_ids=[local_rank], output_device=local_rank) - - # 设置损失函数、优化器和调度器 - criterion = setup_loss_function(conf) - optimizer, scheduler = setup_optimizer_and_scheduler(conf, model, metric) - - # 检查点目录 - checkpoints = conf['training']['checkpoints'] - os.makedirs(checkpoints, exist_ok=True) - - # GradScaler for mixed precision - scaler = torch.cuda.amp.GradScaler() - - return { + # 初始化分布式训练相关参数 + components = { 'conf': conf, - 'train_dataloader': train_dataloader, - 'val_dataloader': val_dataloader, - 'model': model, - 'metric': metric, - 'criterion': criterion, - 'optimizer': optimizer, - 'scheduler': scheduler, - 'checkpoints': checkpoints, - 'scaler': scaler, - 'device': device, - 'distributed': distributed + 'distributed': distributed, + 'device': None, + 'train_dataloader': None, + 'val_dataloader': None, + 'model': None, + 'metric': None, + 'criterion': None, + 'optimizer': None, + 'scheduler': None, + 'checkpoints': None, + 'scaler': None } + # 如果是非分布式训练,直接创建所有组件 + if not distributed: + # 数据加载 + train_dataloader, class_num = load_data(training=True, cfg=conf) + val_dataloader, _ = load_data(training=False, cfg=conf) + + # 初始化模型和度量 + model, metric = initialize_model_and_metric(conf, class_num) + device = conf['base']['device'] + model = model.to(device) + metric = metric.to(device) + + # 设置损失函数、优化器和调度器 + criterion = setup_loss_function(conf) + optimizer, scheduler = setup_optimizer_and_scheduler(conf, model, metric) + + # 检查点目录 + checkpoints = conf['training']['checkpoints'] + os.makedirs(checkpoints, exist_ok=True) + + # GradScaler for mixed precision + scaler = torch.cuda.amp.GradScaler() + + # 更新组件字典 + components.update({ + 'train_dataloader': train_dataloader, + 'val_dataloader': val_dataloader, + 'model': model, + 'metric': metric, + 'criterion': criterion, + 'optimizer': optimizer, + 'scheduler': scheduler, + 'checkpoints': checkpoints, + 'scaler': scaler, + 'device': device + }) + + return components + def run_training_loop(components): """运行完整的训练循环""" @@ -262,9 +244,107 @@ def run_training_loop(components): plt.savefig('loss/mobilenetv3Large_2250_0316.png') -if __name__ == '__main__': - # 初始化训练组件 - components = initialize_training_components() +def main(): + """主函数入口""" + # 加载配置 + conf = load_configuration() + + # 检查是否启用分布式训练 + distributed = conf['base']['distributed'] + + if distributed: + # 分布式训练:使用mp.spawn启动多个进程 + world_size = torch.cuda.device_count() + mp.spawn( + run_training, + args=(world_size, conf), + nprocs=world_size, + join=True + ) + else: + # 单机训练:直接运行训练流程 + components = initialize_training_components(distributed=False) + run_training_loop(components) + + +def run_training(rank, world_size, conf): + """实际执行训练的函数,供mp.spawn调用""" + # 初始化分布式环境 + os.environ['RANK'] = str(rank) + os.environ['WORLD_SIZE'] = str(world_size) + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = '12355' + + dist.init_process_group(backend='nccl', rank=rank, world_size=world_size) + torch.cuda.set_device(rank) + device = torch.device('cuda', rank) + + # 创建数据加载器和模型等组件(分布式情况下) + train_dataloader, class_num = load_data(training=True, cfg=conf) + val_dataloader, _ = load_data(training=False, cfg=conf) + + # 初始化模型和度量 + model, metric = initialize_model_and_metric(conf, class_num) + model = model.to(device) + metric = metric.to(device) + + # 包装为DistributedDataParallel模型 + model = DDP(model, device_ids=[rank], output_device=rank) + metric = DDP(metric, device_ids=[rank], output_device=rank) + + # 设置损失函数、优化器和调度器 + criterion = setup_loss_function(conf) + optimizer, scheduler = setup_optimizer_and_scheduler(conf, model, metric) + + # 检查点目录 + checkpoints = conf['training']['checkpoints'] + os.makedirs(checkpoints, exist_ok=True) + + # GradScaler for mixed precision + scaler = torch.cuda.amp.GradScaler() + + # 创建分布式数据加载器 + train_sampler = DistributedSampler(train_dataloader.dataset, shuffle=True) + val_sampler = DistributedSampler(val_dataloader.dataset, shuffle=False) + + # 重新创建适合分布式训练的数据加载器 + train_dataloader = torch.utils.data.DataLoader( + train_dataloader.dataset, + batch_size=train_dataloader.batch_size, + sampler=train_sampler, + num_workers=train_dataloader.num_workers, + pin_memory=train_dataloader.pin_memory, + drop_last=train_dataloader.drop_last + ) + + val_dataloader = torch.utils.data.DataLoader( + val_dataloader.dataset, + batch_size=val_dataloader.batch_size, + sampler=val_sampler, + num_workers=val_dataloader.num_workers, + pin_memory=val_dataloader.pin_memory, + drop_last=val_dataloader.drop_last + ) + + # 构建组件字典 + components = { + 'conf': conf, + 'train_dataloader': train_dataloader, + 'val_dataloader': val_dataloader, + 'model': model, + 'metric': metric, + 'criterion': criterion, + 'optimizer': optimizer, + 'scheduler': scheduler, + 'checkpoints': checkpoints, + 'scaler': scaler, + 'device': device, + 'distributed': True # 因为是在mp.spawn中运行 + } # 运行训练循环 - run_training_loop(components) \ No newline at end of file + run_training_loop(components) + + +if __name__ == '__main__': + main()