import os
import os.path as osp

import torch
import torch.nn as nn
from tqdm import tqdm
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler

from model.loss import FocalLoss
from tools.dataset import load_data, MultiEpochsDataLoader
import matplotlib.pyplot as plt
from configs import trainer_tools
import yaml
from datetime import datetime

def load_configuration(config_path='configs/compare.yml'):
    """Load the training configuration from a YAML file."""
    with open(config_path, 'r') as f:
        return yaml.load(f, Loader=yaml.FullLoader)
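
# A minimal sketch of the configs/compare.yml layout this script expects,
# inferred from the keys read throughout this file (the key names come from
# the code below; the values shown are illustrative assumptions only):
#
#   base:     {device: cuda, distributed: false, pin_memory: true}
#   models:   {backbone: ...}
#   data:     {train_batch_size: ..., val_batch_size: ..., num_workers: ...}
#   training: {metric: ..., optimizer: ..., scheduler: ..., loss: ...,
#              epochs: ..., checkpoints: ..., restore: false, restore_model: ...}
#   logging:  {logging_dir: ...}
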

def initialize_model_and_metric(conf, class_num):
    """Initialize the backbone model and the metric head."""
    tr_tools = trainer_tools(conf)
    backbone_mapping = tr_tools.get_backbone()
    metric_mapping = tr_tools.get_metric(class_num)

    if conf['models']['backbone'] in backbone_mapping:
        model = backbone_mapping[conf['models']['backbone']]()
    else:
        raise ValueError('Unsupported backbone: {}'.format(conf['models']['backbone']))

    if conf['training']['metric'] in metric_mapping:
        metric = metric_mapping[conf['training']['metric']]()
    else:
        raise ValueError('Unsupported metric type: {}'.format(conf['training']['metric']))

    return model, metric


def setup_optimizer_and_scheduler(conf, model, metric):
    """Set up the optimizer and the learning-rate scheduler."""
    tr_tools = trainer_tools(conf)
    optimizer_mapping = tr_tools.get_optimizer(model, metric)

    if conf['training']['optimizer'] not in optimizer_mapping:
        raise ValueError('Unsupported optimizer type: {}'.format(conf['training']['optimizer']))

    optimizer = optimizer_mapping[conf['training']['optimizer']]()
    scheduler_mapping = tr_tools.get_scheduler(optimizer)
    scheduler = scheduler_mapping[conf['training']['scheduler']]()
    print('Using optimizer {} with scheduler {}'.format(conf['training']['optimizer'],
                                                        conf['training']['scheduler']))
    return optimizer, scheduler


def setup_loss_function(conf):
    """Select the loss function from the configuration."""
    if conf['training']['loss'] == 'focal_loss':
        return FocalLoss(gamma=2)
    else:
        return nn.CrossEntropyLoss()


def train_one_epoch(model, metric, criterion, optimizer, dataloader, device, scaler, conf):
    """Run a single training epoch and return the average loss."""
    model.train()
    train_loss = 0
    for data, labels in tqdm(dataloader, desc="Training", ascii=True, total=len(dataloader)):
        data = data.to(device)
        labels = labels.to(device)

        # Autocast the forward pass so the GradScaler below actually delivers
        # mixed-precision training (it was commented out in the original; the
        # scaler provides no benefit without it).
        with torch.cuda.amp.autocast():
            embeddings = model(data)
            # Margin-based metric heads (e.g. ArcFace-style) need the labels;
            # a plain softmax head only takes the embeddings.
            if conf['training']['metric'] != 'softmax':
                thetas = metric(embeddings, labels)
            else:
                thetas = metric(embeddings)
            loss = criterion(thetas, labels)

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        train_loss += loss.item()
    return train_loss / len(dataloader)


def validate(model, metric, criterion, dataloader, device, conf):
    """Run validation and return the average loss."""
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for data, labels in tqdm(dataloader, desc="Validating", ascii=True, total=len(dataloader)):
            data = data.to(device)
            labels = labels.to(device)
            embeddings = model(data)
            if conf['training']['metric'] != 'softmax':
                thetas = metric(embeddings, labels)
            else:
                thetas = metric(embeddings)
            loss = criterion(thetas, labels)
            val_loss += loss.item()
    return val_loss / len(dataloader)


def save_model(model, path, is_parallel):
    """Save model weights, unwrapping a DataParallel/DDP wrapper if present."""
    if is_parallel:
        torch.save(model.module.state_dict(), path)
    else:
        torch.save(model.state_dict(), path)


def log_training_info(log_path, log_info):
    """Append a line of training information to the log file."""
    with open(log_path, 'a') as f:
        f.write(log_info + '\n')


def initialize_training_components(distributed=False):
    """Initialize all components required for training."""
    # Load the configuration
    conf = load_configuration()

    # Component dictionary, shared by the distributed and single-process paths
    components = {
        'conf': conf,
        'distributed': distributed,
        'device': None,
        'train_dataloader': None,
        'val_dataloader': None,
        'model': None,
        'metric': None,
        'criterion': None,
        'optimizer': None,
        'scheduler': None,
        'checkpoints': None,
        'scaler': None
    }

    # For non-distributed training, create all components directly
    if not distributed:
        # Data loading: load_data(..., return_dataset=True) returns the
        # dataset, which is wrapped in a DataLoader below
        train_dataset, class_num = load_data(training=True, cfg=conf, return_dataset=True)
        val_dataset, _ = load_data(training=False, cfg=conf, return_dataset=True)

        train_dataloader = MultiEpochsDataLoader(train_dataset,
                                                 batch_size=conf['data']['train_batch_size'],
                                                 shuffle=True,
                                                 num_workers=conf['data']['num_workers'],
                                                 pin_memory=conf['base']['pin_memory'],
                                                 drop_last=True)
        val_dataloader = MultiEpochsDataLoader(val_dataset,
                                               batch_size=conf['data']['val_batch_size'],
                                               shuffle=False,
                                               num_workers=conf['data']['num_workers'],
                                               pin_memory=conf['base']['pin_memory'],
                                               drop_last=False)
        # Initialize model and metric head
        model, metric = initialize_model_and_metric(conf, class_num)
        device = conf['base']['device']
        model = model.to(device)
        metric = metric.to(device)

        # Set up the loss function, optimizer and scheduler
        criterion = setup_loss_function(conf)
        optimizer, scheduler = setup_optimizer_and_scheduler(conf, model, metric)

        # Checkpoint directory
        checkpoints = conf['training']['checkpoints']
        os.makedirs(checkpoints, exist_ok=True)

        # GradScaler for mixed precision
        scaler = torch.cuda.amp.GradScaler()

        # Update the component dictionary
        components.update({
            'train_dataloader': train_dataloader,
            'val_dataloader': val_dataloader,
            'model': model,
            'metric': metric,
            'criterion': criterion,
            'optimizer': optimizer,
            'scheduler': scheduler,
            'checkpoints': checkpoints,
            'scaler': scaler,
            'device': device
        })

    return components


def run_training_loop(components):
    """Run the full training loop."""
    # Unpack components
    conf = components['conf']
    train_dataloader = components['train_dataloader']
    val_dataloader = components['val_dataloader']
    model = components['model']
    metric = components['metric']
    criterion = components['criterion']
    optimizer = components['optimizer']
    scheduler = components['scheduler']
    checkpoints = components['checkpoints']
    scaler = components['scaler']
    device = components['device']
    # Present only in the distributed path; needed for per-epoch reshuffling
    train_sampler = components.get('train_sampler')

    # Training state
    train_losses = []
    val_losses = []
    epochs = []
    best_val_loss = float('inf')

    if conf['training']['restore']:
        print('Loading pretrained model: {}'.format(conf['training']['restore_model']))
        model.load_state_dict(torch.load(conf['training']['restore_model'], map_location=device))

    # Training loop
    for e in range(conf['training']['epochs']):
        # DistributedSampler must be told the epoch number, otherwise every
        # epoch sees the same shuffling order
        if train_sampler is not None:
            train_sampler.set_epoch(e)
        train_loss_avg = train_one_epoch(model, metric, criterion, optimizer, train_dataloader, device, scaler, conf)
        train_losses.append(train_loss_avg)
        epochs.append(e)

        val_loss_avg = validate(model, metric, criterion, val_dataloader, device, conf)
        val_losses.append(val_loss_avg)

        if val_loss_avg < best_val_loss:
            save_model(model, osp.join(checkpoints, 'best.pth'),
                       isinstance(model, (nn.DataParallel, DDP)))
            best_val_loss = val_loss_avg

        scheduler.step()
        current_lr = optimizer.param_groups[0]['lr']
        log_info = ("[{:%Y-%m-%d %H:%M:%S}] Epoch {}/{}, train_loss: {}, val_loss: {} lr:{}"
                    .format(datetime.now(),
                            e,
                            conf['training']['epochs'],
                            train_loss_avg,
                            val_loss_avg,
                            current_lr))
        print(log_info)
        log_training_info(conf['logging']['logging_dir'], log_info)
        print("Learning rate at epoch %d: %f" % (e, current_lr))

    # Save the final model
    save_model(model, osp.join(checkpoints, 'last.pth'),
               isinstance(model, (nn.DataParallel, DDP)))

    # Plot the loss curves
    os.makedirs('loss', exist_ok=True)
    plt.plot(epochs, train_losses, color='blue', label='Train Loss')
    plt.plot(epochs, val_losses, color='red', label='Validation Loss')
    plt.legend()
    plt.savefig('loss/mobilenetv3Large_2250_0316.png')


def main():
    """Main entry point."""
    # Load the configuration
    conf = load_configuration()

    # Check whether distributed training is enabled
    distributed = conf['base']['distributed']

    if distributed:
        # Distributed training: use mp.spawn to launch one process per GPU
        world_size = torch.cuda.device_count()
        mp.spawn(
            run_training,
            args=(world_size, conf),
            nprocs=world_size,
            join=True
        )
    else:
        # Single-process training: run the pipeline directly
        components = initialize_training_components(distributed=False)
        run_training_loop(components)


def run_training(rank, world_size, conf):
    """Worker that performs the actual training; invoked by mp.spawn."""
    # Initialize the distributed environment
    os.environ['RANK'] = str(rank)
    os.environ['WORLD_SIZE'] = str(world_size)
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    device = torch.device('cuda', rank)

    # Fetch the datasets rather than DataLoaders
    train_dataset, class_num = load_data(training=True, cfg=conf, return_dataset=True)
    val_dataset, _ = load_data(training=False, cfg=conf, return_dataset=True)

    # Initialize model and metric head
    model, metric = initialize_model_and_metric(conf, class_num)
    model = model.to(device)
    metric = metric.to(device)

    # Wrap in DistributedDataParallel (note: wrapping the metric head assumes
    # it has trainable parameters, e.g. an ArcFace-style margin layer)
    model = DDP(model, device_ids=[rank], output_device=rank)
    metric = DDP(metric, device_ids=[rank], output_device=rank)

    # Set up the loss function, optimizer and scheduler
    criterion = setup_loss_function(conf)
    optimizer, scheduler = setup_optimizer_and_scheduler(conf, model, metric)

    # Checkpoint directory
    checkpoints = conf['training']['checkpoints']
    os.makedirs(checkpoints, exist_ok=True)

    # GradScaler for mixed precision
    scaler = torch.cuda.amp.GradScaler()

    # Create distributed samplers
    train_sampler = DistributedSampler(train_dataset, shuffle=True)
    val_sampler = DistributedSampler(val_dataset, shuffle=False)

    # Build distributed data loaders with MultiEpochsDataLoader
    train_dataloader = MultiEpochsDataLoader(
        train_dataset,
        batch_size=conf['data']['train_batch_size'],
        sampler=train_sampler,
        num_workers=conf['data']['num_workers'],
        pin_memory=conf['base']['pin_memory'],
        drop_last=True
    )

    val_dataloader = MultiEpochsDataLoader(
        val_dataset,
        batch_size=conf['data']['val_batch_size'],
        sampler=val_sampler,
        num_workers=conf['data']['num_workers'],
        pin_memory=conf['base']['pin_memory'],
        drop_last=False
    )

    # Build the component dictionary
    components = {
        'conf': conf,
        'train_dataloader': train_dataloader,
        'val_dataloader': val_dataloader,
        'train_sampler': train_sampler,  # so the loop can call set_epoch()
        'model': model,
        'metric': metric,
        'criterion': criterion,
        'optimizer': optimizer,
        'scheduler': scheduler,
        'checkpoints': checkpoints,
        'scaler': scaler,
        'device': device,
        'distributed': True  # running inside mp.spawn
    }

    # Run the training loop
    run_training_loop(components)
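    # Tear down the process group so each spawned worker releases its NCCL
    # resources before exiting (a small addition; not in the original script).
    dist.destroy_process_group()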
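
# Launch notes (assumptions inferred from the code above, not project docs):
#   single GPU : set base.distributed: false in configs/compare.yml, then run
#                this script directly; components are built in-process.
#   multi GPU  : set base.distributed: true; mp.spawn() starts one worker per
#                visible CUDA device, rendezvousing on localhost:12355 via NCCL.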
if __name__ == '__main__':
    main()