"""Training script.

Builds the model, metric head, optimizer, scheduler and loss from a YAML
configuration, trains with CUDA mixed precision, validates after every epoch,
saves ``best.pth`` / ``last.pth`` checkpoints and plots the loss curves.
"""
import os
import os.path as osp
from datetime import datetime

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import yaml
from tqdm import tqdm

from configs import trainer_tools
from model.loss import FocalLoss
from tools.dataset import load_data


def load_configuration(config_path='configs/scatter.yml'):
    """Load the training configuration from a YAML file.

    Args:
        config_path: Path of the YAML file to read.

    Returns:
        The parsed YAML document (indexed like a nested dict by the callers
        in this file).
    """
    # NOTE(review): yaml.load with FullLoader is more permissive than
    # yaml.safe_load. Confirm the config is always a trusted local file, or
    # switch to safe_load if the config only contains plain YAML types.
    with open(config_path, 'r') as f:
        return yaml.load(f, Loader=yaml.FullLoader)


def initialize_model_and_metric(conf, class_num):
    """Instantiate the backbone model and the metric head named in ``conf``.

    Args:
        conf: Configuration; reads ``conf['models']['backbone']`` and
            ``conf['training']['metric']``.
        class_num: Number of classes, forwarded to ``trainer_tools.get_metric``.

    Returns:
        A ``(model, metric)`` tuple.

    Raises:
        ValueError: If the configured backbone or metric is not offered by
            ``trainer_tools``.
    """
    tr_tools = trainer_tools(conf)
    backbone_mapping = tr_tools.get_backbone()
    metric_mapping = tr_tools.get_metric(class_num)

    backbone_name = conf['models']['backbone']
    if backbone_name not in backbone_mapping:
        # The name is formatted directly; wrapping it in braces would build a
        # set and render the message as "{'name'}".
        raise ValueError('不支持该模型: {}'.format(backbone_name))
    model = backbone_mapping[backbone_name]()

    metric_name = conf['training']['metric']
    if metric_name not in metric_mapping:
        raise ValueError('不支持的metric类型: {}'.format(metric_name))
    metric = metric_mapping[metric_name]()

    return model, metric


def setup_optimizer_and_scheduler(conf, model, metric):
    """Create the optimizer and learning-rate scheduler named in ``conf``.

    Args:
        conf: Configuration; reads ``conf['training']['optimizer']`` and
            ``conf['training']['scheduler']``.
        model: Backbone passed to ``trainer_tools.get_optimizer``.
        metric: Metric head passed to ``trainer_tools.get_optimizer``.

    Returns:
        An ``(optimizer, scheduler)`` tuple.

    Raises:
        ValueError: If the configured optimizer or scheduler is not offered
            by ``trainer_tools``.
    """
    tr_tools = trainer_tools(conf)

    optimizer_name = conf['training']['optimizer']
    optimizer_mapping = tr_tools.get_optimizer(model, metric)
    if optimizer_name not in optimizer_mapping:
        raise ValueError('不支持的优化器类型: {}'.format(optimizer_name))
    optimizer = optimizer_mapping[optimizer_name]()

    scheduler_name = conf['training']['scheduler']
    scheduler_mapping = tr_tools.get_scheduler(optimizer)
    if scheduler_name not in scheduler_mapping:
        # Report an unknown scheduler the same way as an unknown optimizer
        # instead of leaking a bare KeyError from the lookup.
        raise ValueError('不支持的调度器类型: {}'.format(scheduler_name))
    scheduler = scheduler_mapping[scheduler_name]()

    print('使用{}优化器 使用{}调度器'.format(optimizer_name, scheduler_name))
    return optimizer, scheduler


def setup_loss_function(conf):
    """Return the loss selected by ``conf['training']['loss']``.

    ``'focal_loss'`` selects ``FocalLoss(gamma=2)``; any other value falls
    back to ``nn.CrossEntropyLoss``.
    """
    if conf['training']['loss'] == 'focal_loss':
        return FocalLoss(gamma=2)
    return nn.CrossEntropyLoss()


def train_one_epoch(model, metric, criterion, optimizer, dataloader, device,
                    scaler, conf):
    """Run one training epoch under CUDA autocast with gradient scaling.

    Args:
        model: Backbone producing embeddings from a batch.
        metric: Metric head; called with ``(embeddings, labels)`` unless the
            configured metric is ``'softmax'``, in which case it is called
            with the embeddings only.
        criterion: Loss applied to the metric output and the labels.
        optimizer: Optimizer stepped through ``scaler``.
        dataloader: Iterable of ``(data, labels)`` batches that supports
            ``len()``.
        device: Device the batches are moved to.
        scaler: Gradient scaler used for the backward pass and the step.
        conf: Configuration; reads ``conf['training']['metric']``.

    Returns:
        The mean of the per-batch loss values over the epoch.
    """
    model.train()
    # The metric type is constant for the whole epoch; look it up once.
    metric_needs_labels = conf['training']['metric'] != 'softmax'
    train_loss = 0
    for data, labels in tqdm(dataloader, desc="Training", ascii=True,
                             total=len(dataloader)):
        data = data.to(device)
        labels = labels.to(device)
        with torch.cuda.amp.autocast():
            embeddings = model(data)
            if metric_needs_labels:
                thetas = metric(embeddings, labels)
            else:
                thetas = metric(embeddings)
            loss = criterion(thetas, labels)
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        train_loss += loss.item()
    return train_loss / len(dataloader)


def validate(model, metric, criterion, dataloader, device, conf):
    """Compute the mean loss over ``dataloader`` without tracking gradients.

    Args:
        model: Backbone producing embeddings from a batch.
        metric: Metric head; called with ``(embeddings, labels)`` unless the
            configured metric is ``'softmax'``.
        criterion: Loss applied to the metric output and the labels.
        dataloader: Iterable of ``(data, labels)`` batches that supports
            ``len()``.
        device: Device the batches are moved to.
        conf: Configuration; reads ``conf['training']['metric']``.

    Returns:
        The mean of the per-batch loss values.
    """
    model.eval()
    metric_needs_labels = conf['training']['metric'] != 'softmax'
    val_loss = 0
    with torch.no_grad():
        for data, labels in tqdm(dataloader, desc="Validating", ascii=True,
                                 total=len(dataloader)):
            data = data.to(device)
            labels = labels.to(device)
            embeddings = model(data)
            if metric_needs_labels:
                thetas = metric(embeddings, labels)
            else:
                thetas = metric(embeddings)
            loss = criterion(thetas, labels)
            val_loss += loss.item()
    return val_loss / len(dataloader)


def save_model(model, path, is_parallel):
    """Save the model weights to ``path``.

    When ``is_parallel`` is true the wrapped ``model.module`` is saved, so
    the stored keys carry no ``module.`` prefix.
    """
    if is_parallel:
        torch.save(model.module.state_dict(), path)
    else:
        torch.save(model.state_dict(), path)


def log_training_info(log_path, log_info):
    """Append ``log_info`` followed by a newline to the file at ``log_path``.

    The parent directory of ``log_path`` is created when it does not exist,
    so a missing log directory does not abort a training run.
    """
    parent_dir = osp.dirname(log_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(log_path, 'a') as f:
        f.write(log_info + '\n')


def initialize_training_components():
    """Build every object the training loop needs.

    Returns:
        A dict with the keys ``conf``, ``train_dataloader``,
        ``val_dataloader``, ``model``, ``metric``, ``criterion``,
        ``optimizer``, ``scheduler``, ``checkpoints``, ``scaler`` and
        ``device``.
    """
    # Load the configuration.
    conf = load_configuration()

    # Load the data.
    train_dataloader, class_num = load_data(training=True, cfg=conf)
    val_dataloader, _ = load_data(training=False, cfg=conf)

    # Build the model and the metric head and move them to the device.
    model, metric = initialize_model_and_metric(conf, class_num)
    device = conf['base']['device']
    model = model.to(device)
    metric = metric.to(device)

    if torch.cuda.device_count() > 1 and conf['base']['distributed']:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model)
        metric = nn.DataParallel(metric)

    # Loss, optimizer and scheduler.
    criterion = setup_loss_function(conf)
    optimizer, scheduler = setup_optimizer_and_scheduler(conf, model, metric)

    # Checkpoint directory.
    checkpoints = conf['training']['checkpoints']
    os.makedirs(checkpoints, exist_ok=True)

    # GradScaler for mixed precision.
    scaler = torch.cuda.amp.GradScaler()

    return {
        'conf': conf,
        'train_dataloader': train_dataloader,
        'val_dataloader': val_dataloader,
        'model': model,
        'metric': metric,
        'criterion': criterion,
        'optimizer': optimizer,
        'scheduler': scheduler,
        'checkpoints': checkpoints,
        'scaler': scaler,
        'device': device,
    }


def run_training_loop(components,
                      loss_plot_path='loss/mobilenetv3Large_2250_0316.png'):
    """Train for the configured number of epochs, checkpoint, log and plot.

    After each epoch the model is validated; whenever the validation loss
    improves, ``best.pth`` is written to the checkpoint directory. After the
    last epoch ``last.pth`` is written and the train/validation loss curves
    are saved to ``loss_plot_path``.

    Args:
        components: Dict produced by ``initialize_training_components``.
        loss_plot_path: Where the loss-curve image is saved. The parent
            directory is created if needed.
    """
    # Unpack the components.
    conf = components['conf']
    train_dataloader = components['train_dataloader']
    val_dataloader = components['val_dataloader']
    model = components['model']
    metric = components['metric']
    criterion = components['criterion']
    optimizer = components['optimizer']
    scheduler = components['scheduler']
    checkpoints = components['checkpoints']
    scaler = components['scaler']
    device = components['device']

    is_parallel = isinstance(model, nn.DataParallel)

    # Training state. Start from +inf so the first epoch always produces a
    # best checkpoint, however large the initial validation loss is.
    train_losses = []
    val_losses = []
    epochs = []
    temp_loss = float('inf')

    if conf['training']['restore']:
        print('load pretrain model: {}'.format(conf['training']['restore_model']))
        state_dict = torch.load(conf['training']['restore_model'],
                                map_location=device)
        # save_model() stores the unwrapped module's weights (no "module."
        # prefix). A DataParallel wrapper expects that prefix on every key,
        # so such a checkpoint must be loaded into the wrapped module.
        has_wrapper_keys = all(key.startswith('module.') for key in state_dict)
        if is_parallel and not has_wrapper_keys:
            model.module.load_state_dict(state_dict)
        else:
            model.load_state_dict(state_dict)

    # Training loop.
    for e in range(conf['training']['epochs']):
        train_loss_avg = train_one_epoch(model, metric, criterion, optimizer,
                                         train_dataloader, device, scaler, conf)
        train_losses.append(train_loss_avg)
        epochs.append(e)

        val_loss_avg = validate(model, metric, criterion, val_dataloader,
                                device, conf)
        val_losses.append(val_loss_avg)

        if val_loss_avg < temp_loss:
            save_model(model, osp.join(checkpoints, 'best.pth'), is_parallel)
            temp_loss = val_loss_avg

        scheduler.step()
        current_lr = optimizer.param_groups[0]['lr']
        log_info = ("[{:%Y-%m-%d %H:%M:%S}] Epoch {}/{}, train_loss: {}, val_loss: {} lr:{}"
                    .format(datetime.now(), e, conf['training']['epochs'],
                            train_loss_avg, val_loss_avg, current_lr))
        print(log_info)
        # NOTE(review): the key is named "logging_dir" but the value is opened
        # as a file; confirm the config holds a file path.
        log_training_info(osp.join(conf['logging']['logging_dir']), log_info)
        print("第%d个epoch的学习率:%f" % (e, current_lr))

    # Save the final model.
    save_model(model, osp.join(checkpoints, 'last.pth'), is_parallel)

    # Plot the loss curves. Create the target directory first so a missing
    # folder does not raise after training has already finished.
    plot_dir = osp.dirname(loss_plot_path)
    if plot_dir:
        os.makedirs(plot_dir, exist_ok=True)
    plt.plot(epochs, train_losses, color='blue', label='Train Loss')
    plt.plot(epochs, val_losses, color='red', label='Validation Loss')
    plt.legend()
    plt.savefig(loss_plot_path)


if __name__ == '__main__':
    # Build the training components.
    components = initialize_training_components()

    # Run the training loop.
    run_training_loop(components)