ieemoo-ai-contrast/train_compare.py
import os
import os.path as osp
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from model.loss import FocalLoss
from tools.dataset import load_data
import matplotlib.pyplot as plt
from configs import trainer_tools
import yaml
with open('configs/scatter.yml', 'r') as f:
    conf = yaml.load(f, Loader=yaml.FullLoader)
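# Sketch of the keys this script reads from configs/scatter.yml (values below are
# placeholders for illustration, not the project's real settings):
#
#   base:
#     device: cuda:0
#     distributed: true
#   models:
#     backbone: mobilenetv3_large
#   training:
#     metric: arcface
#     loss: focal_loss
#     optimizer: sgd
#     lr_step: 10
#     lr_decay: 0.1
#     epochs: 100
#     restore: false
#     restore_model: checkpoints/best.pth
#     checkpoints: checkpoints/
#   logging:
#     logging_dir: logs/train.log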
# Data Setup
train_dataloader, class_num = load_data(training=True, cfg=conf)
val_dataloader, _ = load_data(training=False, cfg=conf)
tr_tools = trainer_tools(conf)
backbone_mapping = tr_tools.get_backbone()
metric_mapping = tr_tools.get_metric(class_num)
if conf['models']['backbone'] in backbone_mapping:
    model = backbone_mapping[conf['models']['backbone']]().to(conf['base']['device'])
else:
    raise ValueError('Unsupported backbone: {}'.format(conf['models']['backbone']))
if conf['training']['metric'] in metric_mapping:
    metric = metric_mapping[conf['training']['metric']]()
else:
    raise ValueError('Unsupported metric type: {}'.format(conf['training']['metric']))
if torch.cuda.device_count() > 1 and conf['base']['distributed']:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    model = nn.DataParallel(model)
    metric = nn.DataParallel(metric)
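# Note: nn.DataParallel stores the wrapped network under .module, so the multi-GPU
# save paths later in the script use model.module.state_dict() to keep checkpoint
# keys compatible with a single-GPU model.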
# Training Setup
if conf['training']['loss'] == 'focal_loss':
    criterion = FocalLoss(gamma=2)
else:
    criterion = nn.CrossEntropyLoss()
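# For reference, the standard focal loss (Lin et al.) is
#   FL(p_t) = -(1 - p_t) ** gamma * log(p_t)
# so gamma=2 down-weights well-classified examples; the exact behaviour here depends
# on the project's own model.loss.FocalLoss implementation.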
optimizer_mapping = tr_tools.get_optimizer(model, metric)
if conf['training']['optimizer'] in optimizer_mapping:
    optimizer = optimizer_mapping[conf['training']['optimizer']]()
    scheduler = optim.lr_scheduler.StepLR(
        optimizer,
        step_size=conf['training']['lr_step'],
        gamma=conf['training']['lr_decay']
    )
else:
    raise ValueError('Unsupported optimizer type: {}'.format(conf['training']['optimizer']))
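# StepLR multiplies the learning rate by lr_decay every lr_step epochs:
#   lr(e) = initial_lr * lr_decay ** (e // lr_step)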
# Checkpoints Setup
checkpoints = conf['training']['checkpoints']
os.makedirs(checkpoints, exist_ok=True)
if __name__ == '__main__':
    print('backbone>{} '.format(conf['models']['backbone']),
          'metric>{} '.format(conf['training']['metric']),
          'checkpoints>{} '.format(conf['training']['checkpoints']),
          )
    train_losses = []
    val_losses = []
    epochs = []
    temp_loss = 100
    if conf['training']['restore']:
        print('load pretrained model: {}'.format(conf['training']['restore_model']))
        model.load_state_dict(torch.load(conf['training']['restore_model'],
                                         map_location=conf['base']['device']))
    for e in range(conf['training']['epochs']):
        train_loss = 0
        model.train()
        for train_data, train_labels in tqdm(train_dataloader,
                                             desc="Epoch {}/{}".format(e, conf['training']['epochs']),
                                             ascii=True,
                                             total=len(train_dataloader)):
            train_data = train_data.to(conf['base']['device'])
            train_labels = train_labels.to(conf['base']['device'])
            train_embeddings = model(train_data).to(conf['base']['device'])  # [256, 512]
            # pdb.set_trace()
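            # Margin-based heads (e.g. ArcFace-style metrics) take the labels to build
            # their logits; a plain softmax head only needs the embeddings.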
            if conf['training']['metric'] != 'softmax':
                thetas = metric(train_embeddings, train_labels)  # [256, 357]
            else:
                thetas = metric(train_embeddings)
            tloss = criterion(thetas, train_labels)
            optimizer.zero_grad()
            tloss.backward()
            optimizer.step()
            train_loss += tloss.item()
        train_lossAvg = train_loss / len(train_dataloader)
        train_losses.append(train_lossAvg)
        epochs.append(e)
        val_loss = 0
        model.eval()
        with torch.no_grad():
            for val_data, val_labels in tqdm(val_dataloader, desc="val",
                                             ascii=True, total=len(val_dataloader)):
                val_data = val_data.to(conf['base']['device'])
                val_labels = val_labels.to(conf['base']['device'])
                val_embeddings = model(val_data).to(conf['base']['device'])
                if conf['training']['metric'] != 'softmax':
                    thetas = metric(val_embeddings, val_labels)
                else:
                    thetas = metric(val_embeddings)
                vloss = criterion(thetas, val_labels)
                val_loss += vloss.item()
        val_lossAvg = val_loss / len(val_dataloader)
        val_losses.append(val_lossAvg)
        if val_lossAvg < temp_loss:
            if torch.cuda.device_count() > 1 and conf['base']['distributed']:
                # Unwrap DataParallel so checkpoint keys match a single-GPU model,
                # mirroring the final save of last.pth below.
                torch.save(model.module.state_dict(), osp.join(checkpoints, 'best.pth'))
            else:
                torch.save(model.state_dict(), osp.join(checkpoints, 'best.pth'))
            temp_loss = val_lossAvg
        scheduler.step()
        current_lr = optimizer.param_groups[0]['lr']
        log_info = ("Epoch {}/{}, train_loss: {}, val_loss: {} lr:{}"
                    .format(e, conf['training']['epochs'], train_lossAvg, val_lossAvg, current_lr))
        print(log_info)
        # Append the epoch summary to the log file
        with open(osp.join(conf['logging']['logging_dir']), 'a') as f:
            f.write(log_info + '\n')
        print("Learning rate at epoch %d: %f" % (e, current_lr))
    if torch.cuda.device_count() > 1 and conf['base']['distributed']:
        torch.save(model.module.state_dict(), osp.join(checkpoints, 'last.pth'))
    else:
        torch.save(model.state_dict(), osp.join(checkpoints, 'last.pth'))
    plt.plot(epochs, train_losses, color='blue')
    plt.plot(epochs, val_losses, color='red')
    # plt.savefig('lossMobilenetv3.png')
    plt.savefig('loss/mobilenetv3Large_2250_0316.png')