This commit is contained in:
lee
2025-06-11 15:23:50 +08:00
commit 37ecef40f7
79 changed files with 26981 additions and 0 deletions

142
train_compare.py Normal file
View File

@ -0,0 +1,142 @@
import os
import os.path as osp
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from model.loss import FocalLoss
from tools.dataset import load_data
import matplotlib.pyplot as plt
from configs import trainer_tools
import yaml
# Load the experiment configuration (backbone, metric, optimizer, paths, ...).
# NOTE(review): yaml.FullLoader can construct arbitrary Python objects; fine for
# a trusted local config, but prefer yaml.safe_load if the file could be tampered with.
with open('configs/scatter.yml', 'r') as f:
    conf = yaml.load(f, Loader=yaml.FullLoader)

# Data Setup: class_num (number of classes in the training set) sizes the metric head.
train_dataloader, class_num = load_data(training=True, cfg=conf)
val_dataloader, _ = load_data(training=False, cfg=conf)

tr_tools = trainer_tools(conf)
backbone_mapping = tr_tools.get_backbone()
metric_mapping = tr_tools.get_metric(class_num)

# Instantiate the backbone on the configured device.
if conf['models']['backbone'] in backbone_mapping:
    model = backbone_mapping[conf['models']['backbone']]().to(conf['base']['device'])
else:
    # BUG FIX: was .format({conf['models']['backbone']}) — the extra braces built
    # a one-element set, so the message rendered as "... {'name'}" instead of "... name".
    raise ValueError('不支持该模型: {}'.format(conf['models']['backbone']))

# Instantiate the metric head; margin-based heads need class_num (bound above).
if conf['training']['metric'] in metric_mapping:
    metric = metric_mapping[conf['training']['metric']]()
else:
    raise ValueError('不支持的metric类型: {}'.format(conf['training']['metric']))

# Optional multi-GPU data parallelism (both backbone and metric head are wrapped).
if torch.cuda.device_count() > 1 and conf['base']['distributed']:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    model = nn.DataParallel(model)
    metric = nn.DataParallel(metric)

# Training Setup: loss, optimizer (covers both model and metric params), LR schedule.
if conf['training']['loss'] == 'focal_loss':
    criterion = FocalLoss(gamma=2)
else:
    criterion = nn.CrossEntropyLoss()

optimizer_mapping = tr_tools.get_optimizer(model, metric)
if conf['training']['optimizer'] in optimizer_mapping:
    optimizer = optimizer_mapping[conf['training']['optimizer']]()
    scheduler = optim.lr_scheduler.StepLR(
        optimizer,
        step_size=conf['training']['lr_step'],
        gamma=conf['training']['lr_decay']
    )
else:
    raise ValueError('不支持的优化器类型: {}'.format(conf['training']['optimizer']))

# Checkpoints Setup: ensure the output directory exists before training starts.
checkpoints = conf['training']['checkpoints']
os.makedirs(checkpoints, exist_ok=True)
if __name__ == '__main__':
    print('backbone>{} '.format(conf['models']['backbone']),
          'metric>{} '.format(conf['training']['metric']),
          'checkpoints>{} '.format(conf['training']['checkpoints']),
          )
    train_losses = []
    val_losses = []
    epochs = []
    # ROBUSTNESS FIX: was temp_loss = 100 — a run whose validation loss never
    # dropped below 100 would never write best.pth. Start from +inf instead.
    temp_loss = float('inf')

    if conf['training']['restore']:
        # NOTE(review): at this point model may already be wrapped in DataParallel,
        # so the checkpoint's keys must carry the 'module.' prefix to match —
        # confirm restore_model was saved from a wrapped model, or load into
        # model.module instead.
        print('load pretrain model: {}'.format(conf['training']['restore_model']))
        model.load_state_dict(torch.load(conf['training']['restore_model'],
                                         map_location=conf['base']['device']))

    for e in range(conf['training']['epochs']):
        # ---- training epoch ----
        train_loss = 0
        model.train()
        for train_data, train_labels in tqdm(train_dataloader,
                                             desc="Epoch {}/{}"
                                             .format(e, conf['training']['epochs']),
                                             ascii=True,
                                             total=len(train_dataloader)):
            train_data = train_data.to(conf['base']['device'])
            train_labels = train_labels.to(conf['base']['device'])
            train_embeddings = model(train_data).to(conf['base']['device'])  # [256,512]
            # Margin-based metric heads consume the labels; a plain softmax
            # head only takes the embeddings.
            if not conf['training']['metric'] == 'softmax':
                thetas = metric(train_embeddings, train_labels)  # [256,357]
            else:
                thetas = metric(train_embeddings)
            tloss = criterion(thetas, train_labels)
            optimizer.zero_grad()
            tloss.backward()
            optimizer.step()
            train_loss += tloss.item()
        train_lossAvg = train_loss / len(train_dataloader)
        train_losses.append(train_lossAvg)
        epochs.append(e)

        # ---- validation epoch (no gradients) ----
        val_loss = 0
        model.eval()
        with torch.no_grad():
            for val_data, val_labels in tqdm(val_dataloader, desc="val",
                                             ascii=True, total=len(val_dataloader)):
                val_data = val_data.to(conf['base']['device'])
                val_labels = val_labels.to(conf['base']['device'])
                val_embeddings = model(val_data).to(conf['base']['device'])
                if not conf['training']['metric'] == 'softmax':
                    thetas = metric(val_embeddings, val_labels)
                else:
                    thetas = metric(val_embeddings)
                vloss = criterion(thetas, val_labels)
                val_loss += vloss.item()
        val_lossAvg = val_loss / len(val_dataloader)
        val_losses.append(val_lossAvg)

        # Save the best checkpoint so far (lowest validation loss).
        if val_lossAvg < temp_loss:
            # BUG FIX: both branches previously saved model.state_dict(); the
            # multi-GPU branch must unwrap DataParallel (model.module) — matching
            # the 'last.pth' save below — or best.pth keys get a 'module.' prefix.
            # The condition now also checks the distributed flag, consistent with
            # the wrapping condition at setup time.
            if torch.cuda.device_count() > 1 and conf['base']['distributed']:
                torch.save(model.module.state_dict(), osp.join(checkpoints, 'best.pth'))
            else:
                torch.save(model.state_dict(), osp.join(checkpoints, 'best.pth'))
            temp_loss = val_lossAvg

        scheduler.step()
        current_lr = optimizer.param_groups[0]['lr']

        log_info = ("Epoch {}/{}, train_loss: {}, val_loss: {} lr:{}"
                    .format(e, conf['training']['epochs'], train_lossAvg, val_lossAvg, current_lr))
        print(log_info)
        # Append the epoch summary to the log file.
        with open(osp.join(conf['logging']['logging_dir']), 'a') as f:
            f.write(log_info + '\n')
        print("%d个epoch的学习率%f" % (e, current_lr))

    # Persist the final weights (unwrapped when data-parallel).
    if torch.cuda.device_count() > 1 and conf['base']['distributed']:
        torch.save(model.module.state_dict(), osp.join(checkpoints, 'last.pth'))
    else:
        torch.save(model.state_dict(), osp.join(checkpoints, 'last.pth'))

    # Plot train (blue) vs validation (red) loss curves.
    plt.plot(epochs, train_losses, color='blue')
    plt.plot(epochs, val_losses, color='red')
    # plt.savefig('lossMobilenetv3.png')
    plt.savefig('loss/mobilenetv3Large_2250_0316.png')