训练代码优化

This commit is contained in:
lee
2025-07-03 15:16:58 +08:00
parent bcbabd9313
commit 6640f2bc5e

View File

@ -122,84 +122,66 @@ def log_training_info(log_path, log_info):
f.write(log_info + '\n') f.write(log_info + '\n')
def initialize_training_components(): def initialize_training_components(distributed=False):
"""初始化所有训练所需组件""" """初始化所有训练所需组件"""
# 加载配置 # 加载配置
conf = load_configuration() conf = load_configuration()
# 初始化分布式训练 # 初始化分布式训练相关参数
distributed = conf['base']['distributed'] components = {
if distributed:
dist.init_process_group(backend='nccl')
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
device = torch.device('cuda', local_rank)
else:
device = conf['base']['device']
# 数据加载
train_dataloader, class_num = load_data(training=True, cfg=conf)
val_dataloader, _ = load_data(training=False, cfg=conf)
# 如果使用分布式,需要为每个进程创建单独的数据加载器
if distributed:
train_sampler = DistributedSampler(train_dataloader.dataset, shuffle=True)
val_sampler = DistributedSampler(val_dataloader.dataset, shuffle=False)
# 重新创建适合分布式训练的数据加载器
train_dataloader = torch.utils.data.DataLoader(
train_dataloader.dataset,
batch_size=train_dataloader.batch_size,
sampler=train_sampler,
num_workers=train_dataloader.num_workers,
pin_memory=train_dataloader.pin_memory,
drop_last=train_dataloader.drop_last
)
val_dataloader = torch.utils.data.DataLoader(
val_dataloader.dataset,
batch_size=val_dataloader.batch_size,
sampler=val_sampler,
num_workers=val_dataloader.num_workers,
pin_memory=val_dataloader.pin_memory,
drop_last=val_dataloader.drop_last
)
# 初始化模型和度量
model, metric = initialize_model_and_metric(conf, class_num)
model = model.to(device)
metric = metric.to(device)
if distributed:
model = DDP(model, device_ids=[local_rank], output_device=local_rank)
metric = DDP(metric, device_ids=[local_rank], output_device=local_rank)
# 设置损失函数、优化器和调度器
criterion = setup_loss_function(conf)
optimizer, scheduler = setup_optimizer_and_scheduler(conf, model, metric)
# 检查点目录
checkpoints = conf['training']['checkpoints']
os.makedirs(checkpoints, exist_ok=True)
# GradScaler for mixed precision
scaler = torch.cuda.amp.GradScaler()
return {
'conf': conf, 'conf': conf,
'train_dataloader': train_dataloader, 'distributed': distributed,
'val_dataloader': val_dataloader, 'device': None,
'model': model, 'train_dataloader': None,
'metric': metric, 'val_dataloader': None,
'criterion': criterion, 'model': None,
'optimizer': optimizer, 'metric': None,
'scheduler': scheduler, 'criterion': None,
'checkpoints': checkpoints, 'optimizer': None,
'scaler': scaler, 'scheduler': None,
'device': device, 'checkpoints': None,
'distributed': distributed 'scaler': None
} }
# 如果是非分布式训练,直接创建所有组件
if not distributed:
# 数据加载
train_dataloader, class_num = load_data(training=True, cfg=conf)
val_dataloader, _ = load_data(training=False, cfg=conf)
# 初始化模型和度量
model, metric = initialize_model_and_metric(conf, class_num)
device = conf['base']['device']
model = model.to(device)
metric = metric.to(device)
# 设置损失函数、优化器和调度器
criterion = setup_loss_function(conf)
optimizer, scheduler = setup_optimizer_and_scheduler(conf, model, metric)
# 检查点目录
checkpoints = conf['training']['checkpoints']
os.makedirs(checkpoints, exist_ok=True)
# GradScaler for mixed precision
scaler = torch.cuda.amp.GradScaler()
# 更新组件字典
components.update({
'train_dataloader': train_dataloader,
'val_dataloader': val_dataloader,
'model': model,
'metric': metric,
'criterion': criterion,
'optimizer': optimizer,
'scheduler': scheduler,
'checkpoints': checkpoints,
'scaler': scaler,
'device': device
})
return components
def run_training_loop(components): def run_training_loop(components):
"""运行完整的训练循环""" """运行完整的训练循环"""
@ -262,9 +244,107 @@ def run_training_loop(components):
plt.savefig('loss/mobilenetv3Large_2250_0316.png') plt.savefig('loss/mobilenetv3Large_2250_0316.png')
def main():
    """Program entry point: dispatch to distributed or single-process training."""
    # Load configuration to decide the training mode.
    conf = load_configuration()
    distributed = conf['base']['distributed']

    if distributed:
        # One worker process per visible GPU, launched via mp.spawn.
        world_size = torch.cuda.device_count()
        if world_size < 1:
            # Fix: nprocs=0 would silently spawn nothing — fail loudly instead.
            raise RuntimeError(
                "Distributed training requested but no CUDA devices are available."
            )
        mp.spawn(
            run_training,
            args=(world_size, conf),
            nprocs=world_size,
            join=True,
        )
    else:
        # Single-process training: build all components here and run directly.
        components = initialize_training_components(distributed=False)
        run_training_loop(components)
def run_training(rank, world_size, conf):
    """Per-process training worker, invoked by mp.spawn.

    Args:
        rank: index of this worker process; also used as its CUDA device id.
        world_size: total number of spawned worker processes.
        conf: parsed configuration dict shared by all ranks.
    """
    # Bootstrap the distributed environment for this rank.
    os.environ['RANK'] = str(rank)
    os.environ['WORLD_SIZE'] = str(world_size)
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    device = torch.device('cuda', rank)

    try:
        # Base loaders: only their dataset and loader parameters are reused.
        train_dataloader, class_num = load_data(training=True, cfg=conf)
        val_dataloader, _ = load_data(training=False, cfg=conf)

        # Rebuild the loaders with DistributedSamplers so each rank only
        # iterates its own shard of the data.
        train_sampler = DistributedSampler(train_dataloader.dataset, shuffle=True)
        val_sampler = DistributedSampler(val_dataloader.dataset, shuffle=False)
        train_dataloader = torch.utils.data.DataLoader(
            train_dataloader.dataset,
            batch_size=train_dataloader.batch_size,
            sampler=train_sampler,
            num_workers=train_dataloader.num_workers,
            pin_memory=train_dataloader.pin_memory,
            drop_last=train_dataloader.drop_last,
        )
        val_dataloader = torch.utils.data.DataLoader(
            val_dataloader.dataset,
            batch_size=val_dataloader.batch_size,
            sampler=val_sampler,
            num_workers=val_dataloader.num_workers,
            pin_memory=val_dataloader.pin_memory,
            drop_last=val_dataloader.drop_last,
        )

        # Model and metric on this rank's GPU, wrapped for gradient sync.
        model, metric = initialize_model_and_metric(conf, class_num)
        model = DDP(model.to(device), device_ids=[rank], output_device=rank)
        metric = DDP(metric.to(device), device_ids=[rank], output_device=rank)

        # Loss function, optimizer and LR scheduler.
        criterion = setup_loss_function(conf)
        optimizer, scheduler = setup_optimizer_and_scheduler(conf, model, metric)

        # Checkpoint directory.
        checkpoints = conf['training']['checkpoints']
        os.makedirs(checkpoints, exist_ok=True)

        # GradScaler for mixed-precision training.
        scaler = torch.cuda.amp.GradScaler()

        components = {
            'conf': conf,
            'train_dataloader': train_dataloader,
            'val_dataloader': val_dataloader,
            'model': model,
            'metric': metric,
            'criterion': criterion,
            'optimizer': optimizer,
            'scheduler': scheduler,
            'checkpoints': checkpoints,
            'scaler': scaler,
            'device': device,
            'distributed': True,  # always True inside an mp.spawn worker
        }

        # Run the training loop.
        run_training_loop(components)
    finally:
        # Fix: tear down the NCCL process group so every spawned worker exits
        # cleanly even if training raises; the original leaked the group.
        dist.destroy_process_group()
# Script entry point. The guard is required here: mp.spawn workers re-import
# this module, and without it each child would recursively launch training.
if __name__ == '__main__':
    main()