Files
ieemoo-ai-contrast/train_compare.py
2025-08-07 11:00:36 +08:00

363 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import os.path as osp
import torch
import torch.nn as nn
from tqdm import tqdm
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from model.loss import FocalLoss
from tools.dataset import load_data, MultiEpochsDataLoader
import matplotlib.pyplot as plt
from configs import trainer_tools
import yaml
from datetime import datetime
def load_configuration(config_path='configs/compare.yml'):
    """Load the YAML training configuration.

    Args:
        config_path: path of the YAML file to read.

    Returns:
        The parsed YAML document (used by this script as a nested dict of
        settings, e.g. ``conf['training']['epochs']``).
    """
    # Decode explicitly as UTF-8 instead of the platform's locale encoding;
    # otherwise configs containing non-ASCII (e.g. Chinese) comments fail to
    # load on systems whose default encoding is not UTF-8 (e.g. GBK on Windows).
    with open(config_path, 'r', encoding='utf-8') as f:
        return yaml.load(f, Loader=yaml.FullLoader)
def initialize_model_and_metric(conf, class_num):
    """Build the backbone model and the metric head selected in ``conf``.

    Args:
        conf: configuration dict; reads ``conf['models']['backbone']`` and
            ``conf['training']['metric']``.
        class_num: number of classes, forwarded to ``trainer_tools.get_metric``.

    Returns:
        ``(model, metric)`` created by the factories registered in
        ``trainer_tools``.

    Raises:
        ValueError: if the backbone or the metric name is not registered.
    """
    tr_tools = trainer_tools(conf)
    backbone_mapping = tr_tools.get_backbone()
    metric_mapping = tr_tools.get_metric(class_num)

    backbone_name = conf['models']['backbone']
    if backbone_name not in backbone_mapping:
        # Format the name itself: the previous code wrapped it in a set
        # literal and printed "{'name'}" instead of "name".
        raise ValueError('不支持该模型: {}'.format(backbone_name))
    model = backbone_mapping[backbone_name]()

    metric_name = conf['training']['metric']
    if metric_name not in metric_mapping:
        raise ValueError('不支持的metric类型: {}'.format(metric_name))
    metric = metric_mapping[metric_name]()

    return model, metric
def setup_optimizer_and_scheduler(conf, model, metric):
    """Create the optimizer and LR scheduler named in ``conf['training']``.

    Args:
        conf: configuration dict; reads ``conf['training']['optimizer']`` and
            ``conf['training']['scheduler']``.
        model: backbone whose parameters are handed to the optimizer factory.
        metric: metric head, also handed to the optimizer factory.

    Returns:
        ``(optimizer, scheduler)``.

    Raises:
        ValueError: if the optimizer or the scheduler name is not registered.
            (An unknown scheduler previously escaped as a bare ``KeyError``.)
    """
    tr_tools = trainer_tools(conf)
    optimizer_mapping = tr_tools.get_optimizer(model, metric)

    optimizer_name = conf['training']['optimizer']
    if optimizer_name not in optimizer_mapping:
        raise ValueError('不支持的优化器类型: {}'.format(optimizer_name))
    optimizer = optimizer_mapping[optimizer_name]()

    # Scheduler factories need the concrete optimizer instance.
    scheduler_mapping = tr_tools.get_scheduler(optimizer)
    scheduler_name = conf['training']['scheduler']
    if scheduler_name not in scheduler_mapping:
        raise ValueError('不支持的调度器类型: {}'.format(scheduler_name))
    scheduler = scheduler_mapping[scheduler_name]()

    print('使用{}优化器 使用{}调度器'.format(optimizer_name, scheduler_name))
    return optimizer, scheduler
def setup_loss_function(conf):
    """Return the criterion selected by ``conf['training']['loss']``.

    ``'focal_loss'`` gives ``FocalLoss(gamma=2)``; any other value falls back
    to ``nn.CrossEntropyLoss()``.
    """
    wants_focal = conf['training']['loss'] == 'focal_loss'
    return FocalLoss(gamma=2) if wants_focal else nn.CrossEntropyLoss()
def train_one_epoch(model, metric, criterion, optimizer, dataloader, device, scaler, conf):
    """Run one training pass over ``dataloader``.

    Args:
        model: backbone mapping a batch to embeddings.
        metric: head mapping embeddings to logits. Unless
            ``conf['training']['metric']`` is ``'softmax'``, it is also given
            the labels.
        criterion: loss applied to ``(logits, labels)``.
        optimizer: optimizer stepped once per batch through ``scaler``.
        dataloader: iterable of ``(data, labels)`` batches supporting ``len``.
        device: device each batch is moved to.
        scaler: gradient scaler used for backward, step and update.
        conf: configuration dict; reads ``conf['training']['metric']``.

    Returns:
        Mean per-batch training loss.
    """
    model.train()
    # The metric head is a separate nn.Module (it is moved with .to() and
    # wrapped in DDP like the model), so its train/eval mode must be set too.
    metric.train()
    pass_labels = conf['training']['metric'] != 'softmax'

    train_loss = 0.0
    for data, labels in tqdm(dataloader, desc="Training", ascii=True, total=len(dataloader)):
        data = data.to(device)
        labels = labels.to(device)
        # NOTE: autocast is not enabled, so the forward pass runs in full
        # precision. The scaler still scales the loss and skips optimizer
        # steps with non-finite gradients.
        embeddings = model(data)
        thetas = metric(embeddings, labels) if pass_labels else metric(embeddings)
        loss = criterion(thetas, labels)

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        train_loss += loss.item()
    return train_loss / len(dataloader)
def validate(model, metric, criterion, dataloader, device, conf):
    """Compute the mean validation loss without tracking gradients.

    Args:
        model: backbone mapping a batch to embeddings.
        metric: head mapping embeddings to logits. Unless
            ``conf['training']['metric']`` is ``'softmax'``, it is also given
            the labels.
        criterion: loss applied to ``(logits, labels)``.
        dataloader: iterable of ``(data, labels)`` batches supporting ``len``.
        device: device each batch is moved to.
        conf: configuration dict; reads ``conf['training']['metric']``.

    Returns:
        Mean per-batch validation loss.
    """
    model.eval()
    # The metric head is its own nn.Module; without this it would stay in
    # train mode during validation.
    metric.eval()
    pass_labels = conf['training']['metric'] != 'softmax'

    val_loss = 0.0
    with torch.no_grad():
        for data, labels in tqdm(dataloader, desc="Validating", ascii=True, total=len(dataloader)):
            data = data.to(device)
            labels = labels.to(device)
            embeddings = model(data)
            thetas = metric(embeddings, labels) if pass_labels else metric(embeddings)
            loss = criterion(thetas, labels)
            val_loss += loss.item()
    return val_loss / len(dataloader)
def save_model(model, path, is_parallel):
    """Write the model's ``state_dict`` to ``path`` with ``torch.save``.

    Args:
        model: the network, or a wrapper exposing it as ``.module``.
        path: destination file.
        is_parallel: when True, save ``model.module``'s weights instead of the
            wrapper's, so the saved keys are those of the inner network.
    """
    target = model.module if is_parallel else model
    torch.save(target.state_dict(), path)
def log_training_info(log_path, log_info):
    """Append ``log_info`` followed by a newline to the file at ``log_path``."""
    with open(log_path, mode='a') as log_file:
        log_file.writelines((log_info, '\n'))
def initialize_training_components(distributed=False):
    """Collect every object ``run_training_loop`` needs into one dict.

    The configuration is always loaded via ``load_configuration()``. Only the
    non-distributed path builds the remaining components here. With
    ``distributed=True``, every entry except ``conf`` and ``distributed`` is
    returned as ``None``.

    Args:
        distributed: whether the caller intends to train with DDP.

    Returns:
        dict with keys ``conf``, ``distributed``, ``device``,
        ``train_dataloader``, ``val_dataloader``, ``model``, ``metric``,
        ``criterion``, ``optimizer``, ``scheduler``, ``checkpoints``, ``scaler``.
    """
    conf = load_configuration()
    components = {'conf': conf, 'distributed': distributed}
    components.update(dict.fromkeys((
        'device', 'train_dataloader', 'val_dataloader', 'model', 'metric',
        'criterion', 'optimizer', 'scheduler', 'checkpoints', 'scaler',
    )))
    if distributed:
        return components

    # With return_dataset=True, load_data hands back datasets; wrap them here.
    train_dataset, class_num = load_data(training=True, cfg=conf, return_dataset=True)
    val_dataset, _ = load_data(training=False, cfg=conf, return_dataset=True)
    data_cfg = conf['data']
    train_loader = MultiEpochsDataLoader(
        train_dataset,
        batch_size=data_cfg['train_batch_size'],
        shuffle=True,
        num_workers=data_cfg['num_workers'],
        pin_memory=conf['base']['pin_memory'],
        drop_last=True,
    )
    val_loader = MultiEpochsDataLoader(
        val_dataset,
        batch_size=data_cfg['val_batch_size'],
        shuffle=False,
        num_workers=data_cfg['num_workers'],
        pin_memory=conf['base']['pin_memory'],
        drop_last=False,
    )

    model, metric = initialize_model_and_metric(conf, class_num)
    device = conf['base']['device']
    model, metric = model.to(device), metric.to(device)

    criterion = setup_loss_function(conf)
    optimizer, scheduler = setup_optimizer_and_scheduler(conf, model, metric)

    checkpoints = conf['training']['checkpoints']
    os.makedirs(checkpoints, exist_ok=True)

    components.update(
        train_dataloader=train_loader,
        val_dataloader=val_loader,
        model=model,
        metric=metric,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        checkpoints=checkpoints,
        scaler=torch.cuda.amp.GradScaler(),  # gradient scaler for AMP-style training
        device=device,
    )
    return components
def _is_main_process():
    """Return True in a non-distributed run, or on rank 0 of an initialised process group."""
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank() == 0
    return True


def run_training_loop(components):
    """Run the full train/validate loop described by ``components``.

    Works for both the single-process components from
    ``initialize_training_components`` and the DDP components from
    ``run_training``. Each epoch it:
    - trains and validates;
    - saves ``best.pth`` whenever the validation loss improves;
    - steps the LR scheduler;
    - logs one line.

    At the end it saves ``last.pth`` and a train/val loss plot. In a
    distributed run only rank 0 writes checkpoints, logs and the plot, so
    processes do not clobber the same files.

    Args:
        components: dict with keys ``conf``, ``train_dataloader``,
            ``val_dataloader``, ``model``, ``metric``, ``criterion``,
            ``optimizer``, ``scheduler``, ``checkpoints``, ``scaler`` and
            ``device``.
    """
    conf = components['conf']
    train_dataloader = components['train_dataloader']
    val_dataloader = components['val_dataloader']
    model = components['model']
    metric = components['metric']
    criterion = components['criterion']
    optimizer = components['optimizer']
    scheduler = components['scheduler']
    checkpoints = components['checkpoints']
    scaler = components['scaler']
    device = components['device']

    # DistributedDataParallel is not a subclass of nn.DataParallel, so check
    # both. Otherwise DDP checkpoints would be saved with a "module." prefix.
    is_parallel = isinstance(model, (nn.DataParallel, DDP))
    bare_model = model.module if is_parallel else model
    is_main = _is_main_process()
    num_epochs = conf['training']['epochs']

    train_losses = []
    val_losses = []
    epochs = []
    # Start from +inf so the first epoch always yields a best.pth regardless of
    # the loss scale. A fixed threshold (previously 100) could never be beaten.
    best_val_loss = float('inf')

    if conf['training']['restore']:
        if is_main:
            print('load pretrain model: {}'.format(conf['training']['restore_model']))
        # Checkpoints hold unwrapped state dicts (see save_model), so load
        # into the bare network rather than the DDP/DataParallel wrapper.
        bare_model.load_state_dict(torch.load(conf['training']['restore_model'], map_location=device))

    # A DistributedSampler in the DDP path; no set_epoch in the single-process path.
    train_sampler = getattr(train_dataloader, 'sampler', None)

    for e in range(num_epochs):
        if hasattr(train_sampler, 'set_epoch'):
            # DistributedSampler seeds its shuffle with the epoch; without this
            # every epoch iterates the data in the same order.
            # NOTE(review): if MultiEpochsDataLoader keeps one persistent,
            # prefetching iterator, the new seed may apply with a lag — confirm.
            train_sampler.set_epoch(e)

        train_loss_avg = train_one_epoch(model, metric, criterion, optimizer, train_dataloader, device, scaler, conf)
        train_losses.append(train_loss_avg)
        epochs.append(e)

        val_loss_avg = validate(model, metric, criterion, val_dataloader, device, conf)
        val_losses.append(val_loss_avg)
        if val_loss_avg < best_val_loss:
            best_val_loss = val_loss_avg
            if is_main:
                save_model(model, osp.join(checkpoints, 'best.pth'), is_parallel)

        scheduler.step()
        current_lr = optimizer.param_groups[0]['lr']
        if is_main:
            log_info = ("[{:%Y-%m-%d %H:%M:%S}] Epoch {}/{}, train_loss: {}, val_loss: {} lr:{}"
                        .format(datetime.now(),
                                e,
                                num_epochs,
                                train_loss_avg,
                                val_loss_avg,
                                current_lr))
            print(log_info)
            # NOTE(review): 'logging_dir' is used directly as the log *file* path.
            log_training_info(conf['logging']['logging_dir'], log_info)
            print("%d个epoch的学习率%f" % (e, current_lr))

    if not is_main:
        return

    save_model(model, osp.join(checkpoints, 'last.pth'), is_parallel)

    # Plot path is configurable; the default keeps the previous hard-coded file.
    plot_path = conf['logging'].get('loss_plot_path', 'loss/mobilenetv3Large_2250_0316.png')
    plot_dir = osp.dirname(plot_path)
    if plot_dir:
        # Create the directory so a missing folder does not crash after training.
        os.makedirs(plot_dir, exist_ok=True)
    plt.figure()
    plt.plot(epochs, train_losses, color='blue', label='Train Loss')
    plt.plot(epochs, val_losses, color='red', label='Validation Loss')
    plt.legend()
    plt.savefig(plot_path)
    plt.close()
def main():
    """Script entry point: choose DDP or single-process training from the config.

    Reads ``conf['base']['distributed']``. When True, spawns one
    ``run_training`` process per visible CUDA device. Otherwise builds the
    components in-process and runs the training loop directly.

    Raises:
        RuntimeError: if distributed training is requested but no CUDA device
            is visible. ``mp.spawn`` with zero processes would otherwise return
            without training anything.
    """
    conf = load_configuration()

    if not conf['base']['distributed']:
        # Single-process training.
        components = initialize_training_components(distributed=False)
        run_training_loop(components)
        return

    # One process per GPU.
    world_size = torch.cuda.device_count()
    if world_size < 1:
        raise RuntimeError('分布式训练需要至少一个可用的CUDA设备, 当前检测到0个')
    mp.spawn(
        run_training,
        args=(world_size, conf),
        nprocs=world_size,
        join=True
    )
def run_training(rank, world_size, conf):
    """Per-process DDP worker launched by ``mp.spawn`` (one process per GPU).

    Args:
        rank: index of this process, also used as its CUDA device index.
        world_size: total number of processes.
        conf: configuration dict loaded by the parent process.
    """
    os.environ['RANK'] = str(rank)
    os.environ['WORLD_SIZE'] = str(world_size)
    # Rendezvous defaults. An address/port already set in the environment
    # wins, so a busy port can be changed without editing the code.
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '12355')
    dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)
    try:
        torch.cuda.set_device(rank)
        device = torch.device('cuda', rank)

        # Load datasets rather than DataLoaders: the samplers below shard them across ranks.
        train_dataset, class_num = load_data(training=True, cfg=conf, return_dataset=True)
        val_dataset, _ = load_data(training=False, cfg=conf, return_dataset=True)

        model, metric = initialize_model_and_metric(conf, class_num)
        model = model.to(device)
        metric = metric.to(device)
        model = DDP(model, device_ids=[rank], output_device=rank)
        metric = DDP(metric, device_ids=[rank], output_device=rank)

        criterion = setup_loss_function(conf)
        optimizer, scheduler = setup_optimizer_and_scheduler(conf, model, metric)

        checkpoints = conf['training']['checkpoints']
        os.makedirs(checkpoints, exist_ok=True)
        scaler = torch.cuda.amp.GradScaler()

        train_sampler = DistributedSampler(train_dataset, shuffle=True)
        val_sampler = DistributedSampler(val_dataset, shuffle=False)
        train_dataloader = MultiEpochsDataLoader(
            train_dataset,
            batch_size=conf['data']['train_batch_size'],
            sampler=train_sampler,
            num_workers=conf['data']['num_workers'],
            pin_memory=conf['base']['pin_memory'],
            drop_last=True
        )
        val_dataloader = MultiEpochsDataLoader(
            val_dataset,
            batch_size=conf['data']['val_batch_size'],
            sampler=val_sampler,
            num_workers=conf['data']['num_workers'],
            pin_memory=conf['base']['pin_memory'],
            drop_last=False
        )

        components = {
            'conf': conf,
            'train_dataloader': train_dataloader,
            'val_dataloader': val_dataloader,
            'model': model,
            'metric': metric,
            'criterion': criterion,
            'optimizer': optimizer,
            'scheduler': scheduler,
            'checkpoints': checkpoints,
            'scaler': scaler,
            'device': device,
            'distributed': True  # always true here: we are running inside mp.spawn
        }
        run_training_loop(components)
    finally:
        # Tear down the process group even if training raises, releasing
        # NCCL resources instead of leaving them to interpreter exit.
        dist.destroy_process_group()
if __name__ == '__main__':
    # The guard is required: mp.spawn starts fresh interpreters that
    # re-import this module, so training must only begin when the file is
    # executed as a script, not on import.
    main()