Training code optimization
train_compare.py
@@ -122,84 +122,66 @@ def log_training_info(log_path, log_info):
         f.write(log_info + '\n')


-def initialize_training_components():
+def initialize_training_components(distributed=False):
     """Initialize all components required for training"""
     # Load the configuration
     conf = load_configuration()

-    # Initialize distributed training
-    distributed = conf['base']['distributed']
-    if distributed:
-        dist.init_process_group(backend='nccl')
-        local_rank = int(os.environ["LOCAL_RANK"])
-        torch.cuda.set_device(local_rank)
-        device = torch.device('cuda', local_rank)
-    else:
-        device = conf['base']['device']
-
-    # Load the data
-    train_dataloader, class_num = load_data(training=True, cfg=conf)
-    val_dataloader, _ = load_data(training=False, cfg=conf)
-
-    # With distributed training, each process needs its own data loader
-    if distributed:
-        train_sampler = DistributedSampler(train_dataloader.dataset, shuffle=True)
-        val_sampler = DistributedSampler(val_dataloader.dataset, shuffle=False)
-
-        # Re-create data loaders suited to distributed training
-        train_dataloader = torch.utils.data.DataLoader(
-            train_dataloader.dataset,
-            batch_size=train_dataloader.batch_size,
-            sampler=train_sampler,
-            num_workers=train_dataloader.num_workers,
-            pin_memory=train_dataloader.pin_memory,
-            drop_last=train_dataloader.drop_last
-        )
-
-        val_dataloader = torch.utils.data.DataLoader(
-            val_dataloader.dataset,
-            batch_size=val_dataloader.batch_size,
-            sampler=val_sampler,
-            num_workers=val_dataloader.num_workers,
-            pin_memory=val_dataloader.pin_memory,
-            drop_last=val_dataloader.drop_last
-        )
-
-    # Initialize the model and metric
-    model, metric = initialize_model_and_metric(conf, class_num)
-    model = model.to(device)
-    metric = metric.to(device)
-
-    if distributed:
-        model = DDP(model, device_ids=[local_rank], output_device=local_rank)
-        metric = DDP(metric, device_ids=[local_rank], output_device=local_rank)
-
-    # Set up the loss function, optimizer and scheduler
-    criterion = setup_loss_function(conf)
-    optimizer, scheduler = setup_optimizer_and_scheduler(conf, model, metric)
-
-    # Checkpoint directory
-    checkpoints = conf['training']['checkpoints']
-    os.makedirs(checkpoints, exist_ok=True)
-
-    # GradScaler for mixed precision
-    scaler = torch.cuda.amp.GradScaler()
-
-    return {
-        'conf': conf,
-        'train_dataloader': train_dataloader,
-        'val_dataloader': val_dataloader,
-        'model': model,
-        'metric': metric,
-        'criterion': criterion,
-        'optimizer': optimizer,
-        'scheduler': scheduler,
-        'checkpoints': checkpoints,
-        'scaler': scaler,
-        'device': device,
-        'distributed': distributed
-    }
+    # Initialize parameters related to distributed training
+    components = {
+        'conf': conf,
+        'distributed': distributed,
+        'device': None,
+        'train_dataloader': None,
+        'val_dataloader': None,
+        'model': None,
+        'metric': None,
+        'criterion': None,
+        'optimizer': None,
+        'scheduler': None,
+        'checkpoints': None,
+        'scaler': None
+    }
+
+    # For non-distributed training, create all components directly
+    if not distributed:
+        # Load the data
+        train_dataloader, class_num = load_data(training=True, cfg=conf)
+        val_dataloader, _ = load_data(training=False, cfg=conf)
+
+        # Initialize the model and metric
+        model, metric = initialize_model_and_metric(conf, class_num)
+        device = conf['base']['device']
+        model = model.to(device)
+        metric = metric.to(device)
+
+        # Set up the loss function, optimizer and scheduler
+        criterion = setup_loss_function(conf)
+        optimizer, scheduler = setup_optimizer_and_scheduler(conf, model, metric)
+
+        # Checkpoint directory
+        checkpoints = conf['training']['checkpoints']
+        os.makedirs(checkpoints, exist_ok=True)
+
+        # GradScaler for mixed precision
+        scaler = torch.cuda.amp.GradScaler()
+
+        # Update the components dict
+        components.update({
+            'train_dataloader': train_dataloader,
+            'val_dataloader': val_dataloader,
+            'model': model,
+            'metric': metric,
+            'criterion': criterion,
+            'optimizer': optimizer,
+            'scheduler': scheduler,
+            'checkpoints': checkpoints,
+            'scaler': scaler,
+            'device': device
+        })
+
+    return components


 def run_training_loop(components):
     """Run the complete training loop"""
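The hunk above removes the sampler/data-loader rebuild from `initialize_training_components`; the same pattern reappears in `run_training` in the next hunk. For reference, a minimal, self-contained sketch of that rebuild pattern; the helper name `make_distributed_loader` is hypothetical and not part of this commit:

# Hypothetical helper illustrating the rebuild pattern used in this commit:
# wrap the existing loader's dataset in a DistributedSampler and re-create
# the DataLoader with the same settings, letting the sampler shard the data.
import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def make_distributed_loader(loader: DataLoader, shuffle: bool) -> DataLoader:
    sampler = DistributedSampler(loader.dataset, shuffle=shuffle)
    return DataLoader(
        loader.dataset,
        batch_size=loader.batch_size,
        sampler=sampler,              # sampler replaces shuffle=True on the loader
        num_workers=loader.num_workers,
        pin_memory=loader.pin_memory,
        drop_last=loader.drop_last,
    )

One caveat the diff does not address: with a `DistributedSampler`, the training loop should call `train_dataloader.sampler.set_epoch(epoch)` at the start of each epoch, otherwise every epoch reuses the same shuffling order across ranks.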
@@ -262,9 +244,107 @@ def run_training_loop(components):
     plt.savefig('loss/mobilenetv3Large_2250_0316.png')


-if __name__ == '__main__':
-    # Initialize the training components
-    components = initialize_training_components()
+def main():
+    """Main entry point"""
+    # Load the configuration
+    conf = load_configuration()
+
+    # Check whether distributed training is enabled
+    distributed = conf['base']['distributed']
+
+    if distributed:
+        # Distributed training: launch multiple processes via mp.spawn
+        world_size = torch.cuda.device_count()
+        mp.spawn(
+            run_training,
+            args=(world_size, conf),
+            nprocs=world_size,
+            join=True
+        )
+    else:
+        # Single-machine training: run the training pipeline directly
+        components = initialize_training_components(distributed=False)
+        run_training_loop(components)
+
+
+def run_training(rank, world_size, conf):
+    """The function that actually runs training; invoked by mp.spawn"""
+    # Initialize the distributed environment
+    os.environ['RANK'] = str(rank)
+    os.environ['WORLD_SIZE'] = str(world_size)
+    os.environ['MASTER_ADDR'] = 'localhost'
+    os.environ['MASTER_PORT'] = '12355'
+
+    dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)
+    torch.cuda.set_device(rank)
+    device = torch.device('cuda', rank)
+
+    # Create data loaders, model and other components (distributed case)
+    train_dataloader, class_num = load_data(training=True, cfg=conf)
+    val_dataloader, _ = load_data(training=False, cfg=conf)
+
+    # Initialize the model and metric
+    model, metric = initialize_model_and_metric(conf, class_num)
+    model = model.to(device)
+    metric = metric.to(device)
+
+    # Wrap as DistributedDataParallel models
+    model = DDP(model, device_ids=[rank], output_device=rank)
+    metric = DDP(metric, device_ids=[rank], output_device=rank)
+
+    # Set up the loss function, optimizer and scheduler
+    criterion = setup_loss_function(conf)
+    optimizer, scheduler = setup_optimizer_and_scheduler(conf, model, metric)
+
+    # Checkpoint directory
+    checkpoints = conf['training']['checkpoints']
+    os.makedirs(checkpoints, exist_ok=True)
+
+    # GradScaler for mixed precision
+    scaler = torch.cuda.amp.GradScaler()
+
+    # Create samplers for distributed data loading
+    train_sampler = DistributedSampler(train_dataloader.dataset, shuffle=True)
+    val_sampler = DistributedSampler(val_dataloader.dataset, shuffle=False)
+
+    # Re-create data loaders suited to distributed training
+    train_dataloader = torch.utils.data.DataLoader(
+        train_dataloader.dataset,
+        batch_size=train_dataloader.batch_size,
+        sampler=train_sampler,
+        num_workers=train_dataloader.num_workers,
+        pin_memory=train_dataloader.pin_memory,
+        drop_last=train_dataloader.drop_last
+    )
+
+    val_dataloader = torch.utils.data.DataLoader(
+        val_dataloader.dataset,
+        batch_size=val_dataloader.batch_size,
+        sampler=val_sampler,
+        num_workers=val_dataloader.num_workers,
+        pin_memory=val_dataloader.pin_memory,
+        drop_last=val_dataloader.drop_last
+    )
+
+    # Build the components dict
+    components = {
+        'conf': conf,
+        'train_dataloader': train_dataloader,
+        'val_dataloader': val_dataloader,
+        'model': model,
+        'metric': metric,
+        'criterion': criterion,
+        'optimizer': optimizer,
+        'scheduler': scheduler,
+        'checkpoints': checkpoints,
+        'scaler': scaler,
+        'device': device,
+        'distributed': True  # because we are running inside mp.spawn
+    }

     # Run the training loop
     run_training_loop(components)
+
+
+if __name__ == '__main__':
+    main()
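For readers unfamiliar with the launch pattern this hunk introduces, here is a minimal, runnable sketch of `mp.spawn` plus DDP mirroring `main()`/`run_training()` above. The toy `Linear` model is a stand-in, not the project's model, and unlike the commit the sketch also tears the process group down at the end:

# Minimal mp.spawn + DDP skeleton (illustrative; not part of this commit).
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

def worker(rank, world_size):
    # Rendezvous via environment variables, as in run_training() above
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    model = torch.nn.Linear(10, 2).to(rank)   # toy stand-in for the real model
    model = DDP(model, device_ids=[rank], output_device=rank)

    # ... per-rank training loop would go here ...

    dist.destroy_process_group()               # clean shutdown per rank

if __name__ == '__main__':
    world_size = torch.cuda.device_count()
    mp.spawn(worker, args=(world_size,), nprocs=world_size, join=True)

Note that `mp.spawn` supplies `rank` as the first argument automatically and `args` provides the rest, which is why the commit passes `args=(world_size, conf)` to a function with signature `run_training(rank, world_size, conf)`.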