import os
import os.path as osp

import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

from model.loss import FocalLoss
from tools.dataset import load_data
import matplotlib.pyplot as plt
from configs import trainer_tools
import yaml

with open('configs/scatter.yml', 'r') as f:
    conf = yaml.load(f, Loader=yaml.FullLoader)

# Data Setup
train_dataloader, class_num = load_data(training=True, cfg=conf)
val_dataloader, _ = load_data(training=False, cfg=conf)

tr_tools = trainer_tools(conf)
backbone_mapping = tr_tools.get_backbone()
metric_mapping = tr_tools.get_metric(class_num)

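# Model & Metric Setup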
if conf['models']['backbone'] in backbone_mapping:
    model = backbone_mapping[conf['models']['backbone']]().to(conf['base']['device'])
else:
    raise ValueError('Unsupported backbone: {}'.format(conf['models']['backbone']))

if conf['training']['metric'] in metric_mapping:
    metric = metric_mapping[conf['training']['metric']]()
else:
    raise ValueError('Unsupported metric type: {}'.format(conf['training']['metric']))

if torch.cuda.device_count() > 1 and conf['base']['distributed']:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    model = nn.DataParallel(model)
    metric = nn.DataParallel(metric)

# Training Setup
if conf['training']['loss'] == 'focal_loss':
    criterion = FocalLoss(gamma=2)
else:
    criterion = nn.CrossEntropyLoss()

optimizer_mapping = tr_tools.get_optimizer(model, metric)
if conf['training']['optimizer'] in optimizer_mapping:
    optimizer = optimizer_mapping[conf['training']['optimizer']]()
    scheduler = optim.lr_scheduler.StepLR(
        optimizer,
        step_size=conf['training']['lr_step'],
        gamma=conf['training']['lr_decay']
    )
else:
    raise ValueError('Unsupported optimizer type: {}'.format(conf['training']['optimizer']))

# Checkpoints Setup
checkpoints = conf['training']['checkpoints']
os.makedirs(checkpoints, exist_ok=True)

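# Entry point: train for the configured number of epochs, validating after each one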
if __name__ == '__main__':
    print('backbone>{} '.format(conf['models']['backbone']),
          'metric>{} '.format(conf['training']['metric']),
          'checkpoints>{} '.format(conf['training']['checkpoints']),
          )
    train_losses = []
    val_losses = []
    epochs = []
    temp_loss = 100  # best validation loss seen so far (large initial sentinel)
    if conf['training']['restore']:
        print('Loading pretrained model: {}'.format(conf['training']['restore_model']))
        model.load_state_dict(torch.load(conf['training']['restore_model'],
                                         map_location=conf['base']['device']))

    for e in range(conf['training']['epochs']):
        train_loss = 0
        model.train()

        for train_data, train_labels in tqdm(train_dataloader,
                                             desc="Epoch {}/{}".format(e, conf['training']['epochs']),
                                             ascii=True,
                                             total=len(train_dataloader)):
            train_data = train_data.to(conf['base']['device'])
            train_labels = train_labels.to(conf['base']['device'])

            train_embeddings = model(train_data).to(conf['base']['device'])  # [batch_size, embedding_dim], e.g. [256, 512]

            # Margin-based metric heads (e.g. ArcFace-style) need the labels to build the logits;
            # a plain softmax head only takes the embeddings.
            if conf['training']['metric'] != 'softmax':
                thetas = metric(train_embeddings, train_labels)  # [batch_size, class_num], e.g. [256, 357]
            else:
                thetas = metric(train_embeddings)
            tloss = criterion(thetas, train_labels)
            optimizer.zero_grad()
            tloss.backward()
            optimizer.step()
            train_loss += tloss.item()
        train_lossAvg = train_loss / len(train_dataloader)
        train_losses.append(train_lossAvg)
        epochs.append(e)
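        # Validation pass: compute the average loss on the held-out split with gradients disabled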
        val_loss = 0
        model.eval()
        with torch.no_grad():
            for val_data, val_labels in tqdm(val_dataloader, desc="val",
                                             ascii=True, total=len(val_dataloader)):
                val_data = val_data.to(conf['base']['device'])
                val_labels = val_labels.to(conf['base']['device'])
                val_embeddings = model(val_data).to(conf['base']['device'])
                if conf['training']['metric'] != 'softmax':
                    thetas = metric(val_embeddings, val_labels)
                else:
                    thetas = metric(val_embeddings)
                vloss = criterion(thetas, val_labels)
                val_loss += vloss.item()
        val_lossAvg = val_loss / len(val_dataloader)
        val_losses.append(val_lossAvg)
        if val_lossAvg < temp_loss:
            # Save the best model so far; under DataParallel, save the unwrapped
            # module's weights so the checkpoint matches the 'last.pth' save below.
            if torch.cuda.device_count() > 1 and conf['base']['distributed']:
                torch.save(model.module.state_dict(), osp.join(checkpoints, 'best.pth'))
            else:
                torch.save(model.state_dict(), osp.join(checkpoints, 'best.pth'))
            temp_loss = val_lossAvg

        scheduler.step()
        current_lr = optimizer.param_groups[0]['lr']
        log_info = ("Epoch {}/{}, train_loss: {}, val_loss: {} lr:{}"
                    .format(e, conf['training']['epochs'], train_lossAvg, val_lossAvg, current_lr))
        print(log_info)
        # Append the epoch summary to the log file
        with open(osp.join(conf['logging']['logging_dir']), 'a') as f:
            f.write(log_info + '\n')
        print("Learning rate for epoch %d: %f" % (e, current_lr))
    if torch.cuda.device_count() > 1 and conf['base']['distributed']:
        torch.save(model.module.state_dict(), osp.join(checkpoints, 'last.pth'))
    else:
        torch.save(model.state_dict(), osp.join(checkpoints, 'last.pth'))
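    # Plot training (blue) and validation (red) loss curves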
    plt.plot(epochs, train_losses, color='blue')
    plt.plot(epochs, val_losses, color='red')
    plt.savefig('loss/mobilenetv3Large_2250_0316.png')