rebuild
train_compare.py (new file, 142 lines)
@@ -0,0 +1,142 @@
import os
import os.path as osp

import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

from model.loss import FocalLoss
from tools.dataset import load_data
import matplotlib.pyplot as plt
from configs import trainer_tools
import yaml

with open('configs/scatter.yml', 'r') as f:
    conf = yaml.load(f, Loader=yaml.FullLoader)
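# NOTE (assumption): configs/scatter.yml is not part of this diff. Based on the
# keys read below, it is expected to provide at least:
#   base:     device, distributed
#   models:   backbone
#   training: metric, loss, optimizer, lr_step, lr_decay, checkpoints,
#             epochs, restore, restore_model
#   logging:  logging_dir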

# Data Setup
train_dataloader, class_num = load_data(training=True, cfg=conf)
val_dataloader, _ = load_data(training=False, cfg=conf)

tr_tools = trainer_tools(conf)
backbone_mapping = tr_tools.get_backbone()
metric_mapping = tr_tools.get_metric(class_num)
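# trainer_tools is assumed to return name -> constructor mappings for the
# configured backbone, metric head and optimizer (see the lookups below).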

if conf['models']['backbone'] in backbone_mapping:
    model = backbone_mapping[conf['models']['backbone']]().to(conf['base']['device'])
else:
    raise ValueError('Unsupported backbone: {}'.format(conf['models']['backbone']))

if conf['training']['metric'] in metric_mapping:
    metric = metric_mapping[conf['training']['metric']]()
else:
    raise ValueError('Unsupported metric type: {}'.format(conf['training']['metric']))

if torch.cuda.device_count() > 1 and conf['base']['distributed']:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    model = nn.DataParallel(model)
    metric = nn.DataParallel(metric)
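    # nn.DataParallel prefixes parameter names with 'module.', which matters
    # when the state_dicts are saved or restored further down.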

# Training Setup
if conf['training']['loss'] == 'focal_loss':
    criterion = FocalLoss(gamma=2)
else:
    criterion = nn.CrossEntropyLoss()
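# Assuming model.loss.FocalLoss follows the standard focal-loss formulation,
# gamma=2 down-weights well-classified samples relative to plain cross-entropy.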

optimizer_mapping = tr_tools.get_optimizer(model, metric)
if conf['training']['optimizer'] in optimizer_mapping:
    optimizer = optimizer_mapping[conf['training']['optimizer']]()
    scheduler = optim.lr_scheduler.StepLR(
        optimizer,
        step_size=conf['training']['lr_step'],
        gamma=conf['training']['lr_decay']
    )
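    # StepLR multiplies the learning rate by lr_decay every lr_step epochs
    # (scheduler.step() is called once per epoch in the training loop below).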
else:
    raise ValueError('Unsupported optimizer type: {}'.format(conf['training']['optimizer']))

# Checkpoints Setup
checkpoints = conf['training']['checkpoints']
os.makedirs(checkpoints, exist_ok=True)

if __name__ == '__main__':
    print('backbone>{} '.format(conf['models']['backbone']),
          'metric>{} '.format(conf['training']['metric']),
          'checkpoints>{} '.format(conf['training']['checkpoints']),
          )
    train_losses = []
    val_losses = []
    epochs = []
    temp_loss = 100  # best validation loss seen so far
    if conf['training']['restore']:
        print('loading pretrained model: {}'.format(conf['training']['restore_model']))
        model.load_state_dict(torch.load(conf['training']['restore_model'],
                                         map_location=conf['base']['device']))
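        # Note: if the model was wrapped in nn.DataParallel above, the restored
        # checkpoint is assumed to have been saved with matching 'module.' keys.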

    for e in range(conf['training']['epochs']):
        train_loss = 0
        model.train()

        for train_data, train_labels in tqdm(train_dataloader,
                                             desc="Epoch {}/{}"
                                             .format(e, conf['training']['epochs']),
                                             ascii=True,
                                             total=len(train_dataloader)):
            train_data = train_data.to(conf['base']['device'])
            train_labels = train_labels.to(conf['base']['device'])

            train_embeddings = model(train_data).to(conf['base']['device'])  # [256,512]
            # pdb.set_trace()
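
            # Margin-based heads (e.g. ArcFace/CosFace-style) consume the labels
            # to apply their margin; a plain softmax head only needs the
            # embeddings, hence the two call signatures below.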
            if conf['training']['metric'] != 'softmax':
                thetas = metric(train_embeddings, train_labels)  # [256,357]
            else:
                thetas = metric(train_embeddings)
            tloss = criterion(thetas, train_labels)
            optimizer.zero_grad()
            tloss.backward()
            optimizer.step()
            train_loss += tloss.item()
        train_lossAvg = train_loss / len(train_dataloader)
        train_losses.append(train_lossAvg)
        epochs.append(e)
        val_loss = 0
        model.eval()
        with torch.no_grad():
            for val_data, val_labels in tqdm(val_dataloader, desc="val",
                                             ascii=True, total=len(val_dataloader)):
                val_data = val_data.to(conf['base']['device'])
                val_labels = val_labels.to(conf['base']['device'])
                val_embeddings = model(val_data).to(conf['base']['device'])
                if conf['training']['metric'] != 'softmax':
                    thetas = metric(val_embeddings, val_labels)
                else:
                    thetas = metric(val_embeddings)
                vloss = criterion(thetas, val_labels)
                val_loss += vloss.item()
            val_lossAvg = val_loss / len(val_dataloader)
            val_losses.append(val_lossAvg)
            if val_lossAvg < temp_loss:
                # Save the unwrapped weights when DataParallel is in use so that
                # best.pth and last.pth share the same key layout.
                if torch.cuda.device_count() > 1 and conf['base']['distributed']:
                    torch.save(model.module.state_dict(), osp.join(checkpoints, 'best.pth'))
                else:
                    torch.save(model.state_dict(), osp.join(checkpoints, 'best.pth'))
                temp_loss = val_lossAvg

        scheduler.step()
        current_lr = optimizer.param_groups[0]['lr']
        log_info = ("Epoch {}/{}, train_loss: {}, val_loss: {} lr:{}"
                    .format(e, conf['training']['epochs'], train_lossAvg, val_lossAvg, current_lr))
        print(log_info)
        # Append the epoch summary to the log file
        with open(conf['logging']['logging_dir'], 'a') as f:
            f.write(log_info + '\n')
        print("Learning rate for epoch %d: %f" % (e, current_lr))
    if torch.cuda.device_count() > 1 and conf['base']['distributed']:
        torch.save(model.module.state_dict(), osp.join(checkpoints, 'last.pth'))
    else:
        torch.save(model.state_dict(), osp.join(checkpoints, 'last.pth'))
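    # last.pth holds the final-epoch weights; best.pth (saved above) holds the
    # weights with the lowest validation loss.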
    # Plot train (blue) vs. validation (red) loss curves.
    plt.plot(epochs, train_losses, color='blue')
    plt.plot(epochs, val_losses, color='red')
    # plt.savefig('lossMobilenetv3.png')
    plt.savefig('loss/mobilenetv3Large_2250_0316.png')