并行训练代码优化

This commit is contained in:
lee
2025-07-03 14:20:37 +08:00
parent 5deaf4727f
commit bcbabd9313
2 changed files with 369 additions and 322 deletions

View File

@ -3,7 +3,6 @@ import os.path as osp
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from model.loss import FocalLoss
@ -12,116 +11,196 @@ import matplotlib.pyplot as plt
from configs import trainer_tools
import yaml
from datetime import datetime
with open('configs/scatter.yml', 'r') as f:
conf = yaml.load(f, Loader=yaml.FullLoader)
# Data Setup
train_dataloader, class_num = load_data(training=True, cfg=conf)
val_dataloader, _ = load_data(training=False, cfg=conf)
tr_tools = trainer_tools(conf)
backbone_mapping = tr_tools.get_backbone()
metric_mapping = tr_tools.get_metric(class_num)
def load_configuration(config_path='configs/scatter.yml'):
"""加载配置文件"""
with open(config_path, 'r') as f:
return yaml.load(f, Loader=yaml.FullLoader)
if conf['models']['backbone'] in backbone_mapping:
model = backbone_mapping[conf['models']['backbone']]().to(conf['base']['device'])
else:
raise ValueError('不支持该模型: {}'.format({conf['models']['backbone']}))
if conf['training']['metric'] in metric_mapping:
metric = metric_mapping[conf['training']['metric']]()
else:
raise ValueError('不支持的metric类型: {}'.format(conf['training']['metric']))
def initialize_model_and_metric(conf, class_num):
"""初始化模型和度量方法"""
tr_tools = trainer_tools(conf)
backbone_mapping = tr_tools.get_backbone()
metric_mapping = tr_tools.get_metric(class_num)
if torch.cuda.device_count() > 1 and conf['base']['distributed']:
print("Let's use", torch.cuda.device_count(), "GPUs!")
model = nn.DataParallel(model)
metric = nn.DataParallel(metric)
if conf['models']['backbone'] in backbone_mapping:
model = backbone_mapping[conf['models']['backbone']]()
else:
raise ValueError('不支持该模型: {}'.format({conf['models']['backbone']}))
# Training Setup
if conf['training']['loss'] == 'focal_loss':
criterion = FocalLoss(gamma=2)
else:
criterion = nn.CrossEntropyLoss()
if conf['training']['metric'] in metric_mapping:
metric = metric_mapping[conf['training']['metric']]()
else:
raise ValueError('不支持的metric类型: {}'.format(conf['training']['metric']))
optimizer_mapping = tr_tools.get_optimizer(model, metric)
if conf['training']['optimizer'] in optimizer_mapping:
optimizer = optimizer_mapping[conf['training']['optimizer']]()
scheduler_mapping = tr_tools.get_scheduler(optimizer)
scheduler = scheduler_mapping[conf['training']['scheduler']]()
print('使用{}优化器 使用{}调度器'.format(conf['training']['optimizer'],
conf['training']['scheduler']))
return model, metric
else:
raise ValueError('不支持的优化器类型: {}'.format(conf['training']['optimizer']))
# Checkpoints Setup
checkpoints = conf['training']['checkpoints']
os.makedirs(checkpoints, exist_ok=True)
def setup_optimizer_and_scheduler(conf, model, metric):
"""设置优化器和学习率调度器"""
tr_tools = trainer_tools(conf)
optimizer_mapping = tr_tools.get_optimizer(model, metric)
if __name__ == '__main__':
print('backbone>{} '.format(conf['models']['backbone']),
'metric>{} '.format(conf['training']['metric']),
'checkpoints>{} '.format(conf['training']['checkpoints']),
)
if conf['training']['optimizer'] in optimizer_mapping:
optimizer = optimizer_mapping[conf['training']['optimizer']]()
scheduler_mapping = tr_tools.get_scheduler(optimizer)
scheduler = scheduler_mapping[conf['training']['scheduler']]()
print('使用{}优化器 使用{}调度器'.format(conf['training']['optimizer'],
conf['training']['scheduler']))
return optimizer, scheduler
else:
raise ValueError('不支持的优化器类型: {}'.format(conf['training']['optimizer']))
def setup_loss_function(conf):
"""配置损失函数"""
if conf['training']['loss'] == 'focal_loss':
return FocalLoss(gamma=2)
else:
return nn.CrossEntropyLoss()
def train_one_epoch(model, metric, criterion, optimizer, dataloader, device, scaler, conf):
"""执行单个训练周期"""
model.train()
train_loss = 0
for data, labels in tqdm(dataloader, desc="Training", ascii=True, total=len(dataloader)):
data = data.to(device)
labels = labels.to(device)
with torch.cuda.amp.autocast():
embeddings = model(data)
if not conf['training']['metric'] == 'softmax':
thetas = metric(embeddings, labels)
else:
thetas = metric(embeddings)
loss = criterion(thetas, labels)
optimizer.zero_grad()
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
train_loss += loss.item()
return train_loss / len(dataloader)
def validate(model, metric, criterion, dataloader, device, conf):
"""执行验证"""
model.eval()
val_loss = 0
with torch.no_grad():
for data, labels in tqdm(dataloader, desc="Validating", ascii=True, total=len(dataloader)):
data = data.to(device)
labels = labels.to(device)
embeddings = model(data)
if not conf['training']['metric'] == 'softmax':
thetas = metric(embeddings, labels)
else:
thetas = metric(embeddings)
loss = criterion(thetas, labels)
val_loss += loss.item()
return val_loss / len(dataloader)
def save_model(model, path, is_parallel):
"""保存模型权重"""
if is_parallel:
torch.save(model.module.state_dict(), path)
else:
torch.save(model.state_dict(), path)
def log_training_info(log_path, log_info):
"""记录训练信息到日志文件"""
with open(log_path, 'a') as f:
f.write(log_info + '\n')
def initialize_training_components():
"""初始化所有训练所需组件"""
# 加载配置
conf = load_configuration()
# 数据加载
train_dataloader, class_num = load_data(training=True, cfg=conf)
val_dataloader, _ = load_data(training=False, cfg=conf)
# 初始化模型和度量
model, metric = initialize_model_and_metric(conf, class_num)
device = conf['base']['device']
model = model.to(device)
metric = metric.to(device)
if torch.cuda.device_count() > 1 and conf['base']['distributed']:
print("Let's use", torch.cuda.device_count(), "GPUs!")
model = nn.DataParallel(model)
metric = nn.DataParallel(metric)
# 设置损失函数、优化器和调度器
criterion = setup_loss_function(conf)
optimizer, scheduler = setup_optimizer_and_scheduler(conf, model, metric)
# 检查点目录
checkpoints = conf['training']['checkpoints']
os.makedirs(checkpoints, exist_ok=True)
# GradScaler for mixed precision
scaler = torch.cuda.amp.GradScaler()
return {
'conf': conf,
'train_dataloader': train_dataloader,
'val_dataloader': val_dataloader,
'model': model,
'metric': metric,
'criterion': criterion,
'optimizer': optimizer,
'scheduler': scheduler,
'checkpoints': checkpoints,
'scaler': scaler,
'device': device
}
def run_training_loop(components):
"""运行完整的训练循环"""
# 解包组件
conf = components['conf']
train_dataloader = components['train_dataloader']
val_dataloader = components['val_dataloader']
model = components['model']
metric = components['metric']
criterion = components['criterion']
optimizer = components['optimizer']
scheduler = components['scheduler']
checkpoints = components['checkpoints']
scaler = components['scaler']
device = components['device']
# 训练状态
train_losses = []
val_losses = []
epochs = []
temp_loss = 100
if conf['training']['restore']:
print('load pretrain model: {}'.format(conf['training']['restore_model']))
model.load_state_dict(torch.load(conf['training']['restore_model'],
map_location=conf['base']['device']))
model.load_state_dict(torch.load(conf['training']['restore_model'], map_location=device))
# 训练循环
for e in range(conf['training']['epochs']):
train_loss = 0
model.train()
for train_data, train_labels in tqdm(train_dataloader,
desc="Epoch {}/{}"
.format(e, conf['training']['epochs']),
ascii=True,
total=len(train_dataloader)):
train_data = train_data.to(conf['base']['device'])
train_labels = train_labels.to(conf['base']['device'])
train_embeddings = model(train_data).to(conf['base']['device']) # [256,512]
# pdb.set_trace()
if not conf['training']['metric'] == 'softmax':
thetas = metric(train_embeddings, train_labels) # [256,357]
else:
thetas = metric(train_embeddings)
tloss = criterion(thetas, train_labels)
optimizer.zero_grad()
tloss.backward()
optimizer.step()
train_loss += tloss.item()
train_lossAvg = train_loss / len(train_dataloader)
train_losses.append(train_lossAvg)
train_loss_avg = train_one_epoch(model, metric, criterion, optimizer, train_dataloader, device, scaler, conf)
train_losses.append(train_loss_avg)
epochs.append(e)
val_loss = 0
model.eval()
with torch.no_grad():
for val_data, val_labels in tqdm(val_dataloader, desc="val",
ascii=True, total=len(val_dataloader)):
val_data = val_data.to(conf['base']['device'])
val_labels = val_labels.to(conf['base']['device'])
val_embeddings = model(val_data).to(conf['base']['device'])
if not conf['training']['metric'] == 'softmax':
thetas = metric(val_embeddings, val_labels)
else:
thetas = metric(val_embeddings)
vloss = criterion(thetas, val_labels)
val_loss += vloss.item()
val_lossAvg = val_loss / len(val_dataloader)
val_losses.append(val_lossAvg)
if val_lossAvg < temp_loss:
if torch.cuda.device_count() > 1:
torch.save(model.state_dict(), osp.join(checkpoints, 'best.pth'))
else:
torch.save(model.state_dict(), osp.join(checkpoints, 'best.pth'))
temp_loss = val_lossAvg
val_loss_avg = validate(model, metric, criterion, val_dataloader, device, conf)
val_losses.append(val_loss_avg)
if val_loss_avg < temp_loss:
save_model(model, osp.join(checkpoints, 'best.pth'), isinstance(model, nn.DataParallel))
temp_loss = val_loss_avg
scheduler.step()
current_lr = optimizer.param_groups[0]['lr']
@ -129,19 +208,26 @@ if __name__ == '__main__':
.format(datetime.now(),
e,
conf['training']['epochs'],
train_lossAvg,
val_lossAvg,
train_loss_avg,
val_loss_avg,
current_lr))
print(log_info)
# 写入日志文件
with open(osp.join(conf['logging']['logging_dir']), 'a') as f:
f.write(log_info + '\n')
log_training_info(osp.join(conf['logging']['logging_dir']), log_info)
print("%d个epoch的学习率%f" % (e, current_lr))
if torch.cuda.device_count() > 1 and conf['base']['distributed']:
torch.save(model.module.state_dict(), osp.join(checkpoints, 'last.pth'))
else:
torch.save(model.state_dict(), osp.join(checkpoints, 'last.pth'))
plt.plot(epochs, train_losses, color='blue')
plt.plot(epochs, val_losses, color='red')
# plt.savefig('lossMobilenetv3.png')
# 保存最终模型
save_model(model, osp.join(checkpoints, 'last.pth'), isinstance(model, nn.DataParallel))
# 绘制损失曲线
plt.plot(epochs, train_losses, color='blue', label='Train Loss')
plt.plot(epochs, val_losses, color='red', label='Validation Loss')
plt.legend()
plt.savefig('loss/mobilenetv3Large_2250_0316.png')
if __name__ == '__main__':
# 初始化训练组件
components = initialize_training_components()
# 运行训练循环
run_training_loop(components)