import os
import os.path as osp

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import yaml
from tqdm import tqdm

from configs import trainer_tools
from model.loss import FocalLoss
from tools.dataset import load_data

# Load the experiment configuration
with open('configs/compare.yml', 'r') as f:
    conf = yaml.load(f, Loader=yaml.FullLoader)
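# The keys accessed below imply a config shaped roughly like the following.
# This is a sketch: the keys are taken from this script, but the concrete
# values are placeholder assumptions, not the project's actual settings.
#
# base:
#   device: cuda
#   distributed: true
# models:
#   backbone: mobilenetv3_large
# training:
#   metric: arcface
#   loss: focal_loss
#   optimizer: sgd
#   scheduler: cosine
#   epochs: 100
#   checkpoints: checkpoints/
#   restore: false
#   restore_model: ''
# logging:
#   logging_dir: logs/train.log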
# Data Setup
train_dataloader, class_num = load_data(training=True, cfg=conf)
val_dataloader, _ = load_data(training=False, cfg=conf)

tr_tools = trainer_tools(conf)
backbone_mapping = tr_tools.get_backbone()
metric_mapping = tr_tools.get_metric(class_num)

if conf['models']['backbone'] in backbone_mapping:
    model = backbone_mapping[conf['models']['backbone']]().to(conf['base']['device'])
else:
    raise ValueError('Unsupported backbone: {}'.format(conf['models']['backbone']))

if conf['training']['metric'] in metric_mapping:
    metric = metric_mapping[conf['training']['metric']]()
else:
    raise ValueError('Unsupported metric type: {}'.format(conf['training']['metric']))

if torch.cuda.device_count() > 1 and conf['base']['distributed']:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    model = nn.DataParallel(model)
    metric = nn.DataParallel(metric)

# Training Setup
if conf['training']['loss'] == 'focal_loss':
    # Focal loss down-weights easy examples: FL(p_t) = -(1 - p_t)^gamma * log(p_t)
    criterion = FocalLoss(gamma=2)
else:
    criterion = nn.CrossEntropyLoss()

optimizer_mapping = tr_tools.get_optimizer(model, metric)
if conf['training']['optimizer'] in optimizer_mapping:
    optimizer = optimizer_mapping[conf['training']['optimizer']]()
    scheduler_mapping = tr_tools.get_scheduler(optimizer)
    scheduler = scheduler_mapping[conf['training']['scheduler']]()
    print('Using {} optimizer with {} scheduler'.format(conf['training']['optimizer'],
                                                        conf['training']['scheduler']))
else:
    raise ValueError('Unsupported optimizer type: {}'.format(conf['training']['optimizer']))

# Checkpoints Setup
checkpoints = conf['training']['checkpoints']
os.makedirs(checkpoints, exist_ok=True)

if __name__ == '__main__':
    print('backbone>{} '.format(conf['models']['backbone']),
          'metric>{} '.format(conf['training']['metric']),
          'checkpoints>{} '.format(conf['training']['checkpoints']))

    train_losses = []
    val_losses = []
    epochs = []
    best_val_loss = float('inf')

    if conf['training']['restore']:
        print('Loading pretrained model: {}'.format(conf['training']['restore_model']))
        model.load_state_dict(torch.load(conf['training']['restore_model'],
                                         map_location=conf['base']['device']))

    for e in range(conf['training']['epochs']):
        # ---- training phase ----
        train_loss = 0
        model.train()
        for train_data, train_labels in tqdm(train_dataloader,
                                             desc="Epoch {}/{}".format(e, conf['training']['epochs']),
                                             ascii=True, total=len(train_dataloader)):
            train_data = train_data.to(conf['base']['device'])
            train_labels = train_labels.to(conf['base']['device'])

            train_embeddings = model(train_data)  # e.g. [batch_size, 512]
            if conf['training']['metric'] != 'softmax':
                # margin-based heads (e.g. ArcFace) also need the labels
                thetas = metric(train_embeddings, train_labels)  # e.g. [batch_size, class_num]
            else:
                thetas = metric(train_embeddings)
            tloss = criterion(thetas, train_labels)

            optimizer.zero_grad()
            tloss.backward()
            optimizer.step()
            train_loss += tloss.item()

        train_lossAvg = train_loss / len(train_dataloader)
        train_losses.append(train_lossAvg)
        epochs.append(e)

        # ---- validation phase ----
        val_loss = 0
        model.eval()
        with torch.no_grad():
            for val_data, val_labels in tqdm(val_dataloader, desc="val",
                                             ascii=True, total=len(val_dataloader)):
                val_data = val_data.to(conf['base']['device'])
                val_labels = val_labels.to(conf['base']['device'])
                val_embeddings = model(val_data)
                if conf['training']['metric'] != 'softmax':
                    thetas = metric(val_embeddings, val_labels)
                else:
                    thetas = metric(val_embeddings)
                vloss = criterion(thetas, val_labels)
                val_loss += vloss.item()

        val_lossAvg = val_loss / len(val_dataloader)
        val_losses.append(val_lossAvg)

        # keep the checkpoint with the lowest validation loss
        if val_lossAvg < best_val_loss:
            if torch.cuda.device_count() > 1 and conf['base']['distributed']:
                # unwrap nn.DataParallel so the checkpoint loads on a single device
                torch.save(model.module.state_dict(), osp.join(checkpoints, 'best.pth'))
            else:
                torch.save(model.state_dict(), osp.join(checkpoints, 'best.pth'))
            best_val_loss = val_lossAvg

        scheduler.step()
        current_lr = optimizer.param_groups[0]['lr']

        log_info = ("Epoch {}/{}, train_loss: {}, val_loss: {} lr:{}"
                    .format(e, conf['training']['epochs'], train_lossAvg, val_lossAvg, current_lr))
        print(log_info)
        # append to the log file
        with open(conf['logging']['logging_dir'], 'a') as f:
            f.write(log_info + '\n')
        print('Learning rate at epoch %d: %f' % (e, current_lr))

        if torch.cuda.device_count() > 1 and conf['base']['distributed']:
            torch.save(model.module.state_dict(), osp.join(checkpoints, 'last.pth'))
        else:
            torch.save(model.state_dict(), osp.join(checkpoints, 'last.pth'))

    # plot the train/val loss curves
    plt.plot(epochs, train_losses, color='blue', label='train')
    plt.plot(epochs, val_losses, color='red', label='val')
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.legend()
    os.makedirs('loss', exist_ok=True)
    plt.savefig('loss/mobilenetv3Large_2250_0316.png')
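# Typical usage (assuming this file is the training entry point, e.g. train.py —
# the filename is an assumption, not given by the script):
#   python train.py
# Writes best.pth / last.pth to conf['training']['checkpoints'], appends per-epoch
# losses to conf['logging']['logging_dir'], and saves the loss curve under loss/.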