Parallel training code optimization
train_compare.py | 397
@@ -3,12 +3,11 @@ import os.path as osp
 import torch
 import torch.nn as nn
-import torch.optim as optim
 from tqdm import tqdm
-from torch.nn.parallel import DistributedDataParallel as DDP
 import torch.distributed as dist
 import torch.multiprocessing as mp
-import torch.utils.data.distributed
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.data.distributed import DistributedSampler
 
 from model.loss import FocalLoss
 from tools.dataset import load_data
@@ -18,146 +17,88 @@ import yaml
 from datetime import datetime
 
 
-def load_config(config_path='configs/scatter.yml'):
-    """加载配置文件."""
+def load_configuration(config_path='configs/scatter.yml'):
+    """加载配置文件"""
     with open(config_path, 'r') as f:
         return yaml.load(f, Loader=yaml.FullLoader)
 
 
-# 加载配置
-conf = load_config()
-
-
-# 数据加载封装
-def load_datasets():
-    """加载训练和验证数据集,并为分布式训练创建 DistributedSampler."""
-    train_dataset, class_num = load_data(training=True, cfg=conf)
-    val_dataset, _ = load_data(training=False, cfg=conf)
-
-    if conf['base']['distributed']:
-        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
-        val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
-    else:
-        train_sampler = None
-        val_sampler = None
-
-    train_dataloader = torch.utils.data.DataLoader(
-        train_dataset,
-        batch_size=conf['data']['batch_size'],
-        shuffle=(train_sampler is None),
-        num_workers=conf['data']['num_workers'],
-        pin_memory=conf['data']['pin_memory'],
-        sampler=train_sampler
-    )
-
-    val_dataloader = torch.utils.data.DataLoader(
-        val_dataset,
-        batch_size=conf['data']['batch_size'],
-        shuffle=False,
-        num_workers=conf['data']['num_workers'],
-        pin_memory=conf['data']['pin_memory'],
-        sampler=val_sampler
-    )
-
-    return train_dataloader, val_dataloader, class_num
-
-
-# 加载数据集
-train_dataloader, val_dataloader, class_num = load_datasets()
-
-tr_tools = trainer_tools(conf)
-backbone_mapping = tr_tools.get_backbone()
-metric_mapping = tr_tools.get_metric(class_num)
-
-
-# 设备管理封装
-def get_device(device_config=None):
-    """根据配置返回设备(CPU/GPU)。在分布式环境下初始化进程组."""
-    if device_config is None:
-        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-    else:
-        device = torch.device(device_config)
-
-    if conf['base']['distributed']:
-        dist.init_process_group(backend='nccl')
-
-    return device
-
-
-# 获取设备
-device = get_device(conf['base']['device'])
-
-# 模型初始化
-if conf['models']['backbone'] in backbone_mapping:
-    model = backbone_mapping[conf['models']['backbone']]().to(device)
-else:
-    raise ValueError('不支持该模型: {}'.format({conf['models']['backbone']}))
-
-if conf['training']['metric'] in metric_mapping:
-    metric = metric_mapping[conf['training']['metric']]().to(device)
-else:
-    raise ValueError('不支持的metric类型: {}'.format(conf['training']['metric']))
-
-rank = 0
-if torch.cuda.device_count() > 1 and conf['base']['distributed']:
-    print("Let's use", torch.cuda.device_count(), "GPUs!")
-    dist.barrier()
-    model = DDP(model, device_ids=[rank])
-    metric = DDP(metric, device_ids=[rank])
-
-# Training Setup
-def initialize_components():
-    # 封装模型、损失函数和优化器的初始化
-    if conf['training']['loss'] == 'focal_loss':
-        criterion = FocalLoss(gamma=2)
-    else:
-        criterion = nn.CrossEntropyLoss()
-
+def initialize_model_and_metric(conf, class_num):
+    """初始化模型和度量方法"""
+    tr_tools = trainer_tools(conf)
+    backbone_mapping = tr_tools.get_backbone()
+    metric_mapping = tr_tools.get_metric(class_num)
+
+    if conf['models']['backbone'] in backbone_mapping:
+        model = backbone_mapping[conf['models']['backbone']]()
+    else:
+        raise ValueError('不支持该模型: {}'.format({conf['models']['backbone']}))
+
+    if conf['training']['metric'] in metric_mapping:
+        metric = metric_mapping[conf['training']['metric']]()
+    else:
+        raise ValueError('不支持的metric类型: {}'.format(conf['training']['metric']))
+
+    return model, metric
+
+
+def setup_optimizer_and_scheduler(conf, model, metric):
+    """设置优化器和学习率调度器"""
+    tr_tools = trainer_tools(conf)
     optimizer_mapping = tr_tools.get_optimizer(model, metric)
 
     if conf['training']['optimizer'] in optimizer_mapping:
         optimizer = optimizer_mapping[conf['training']['optimizer']]()
         scheduler_mapping = tr_tools.get_scheduler(optimizer)
         scheduler = scheduler_mapping[conf['training']['scheduler']]()
         print('使用{}优化器 使用{}调度器'.format(conf['training']['optimizer'],
                                          conf['training']['scheduler']))
-        return criterion, optimizer, scheduler
+        return optimizer, scheduler
     else:
         raise ValueError('不支持的优化器类型: {}'.format(conf['training']['optimizer']))
 
 
-# 初始化组件
-criterion, optimizer, scheduler = initialize_components()
-
-# Checkpoints Setup
-checkpoints = conf['training']['checkpoints']
-os.makedirs(checkpoints, exist_ok=True)
+def setup_loss_function(conf):
+    """配置损失函数"""
+    if conf['training']['loss'] == 'focal_loss':
+        return FocalLoss(gamma=2)
+    else:
+        return nn.CrossEntropyLoss()
 
 
-def train_epoch(model, dataloader, optimizer, criterion, device):
+def train_one_epoch(model, metric, criterion, optimizer, dataloader, device, scaler, conf):
+    """执行单个训练周期"""
     model.train()
     train_loss = 0
     for data, labels in tqdm(dataloader, desc="Training", ascii=True, total=len(dataloader)):
-        data, labels = data.to(device), labels.to(device)
-        embeddings = model(data).to(device)
-        if not conf['training']['metric'] == 'softmax':
-            thetas = metric(embeddings, labels)
-        else:
-            thetas = metric(embeddings)
-        loss = criterion(thetas, labels)
+        data = data.to(device)
+        labels = labels.to(device)
+
+        with torch.cuda.amp.autocast():
+            embeddings = model(data)
+            if not conf['training']['metric'] == 'softmax':
+                thetas = metric(embeddings, labels)
+            else:
+                thetas = metric(embeddings)
+            loss = criterion(thetas, labels)
 
         optimizer.zero_grad()
-        loss.backward()
-        optimizer.step()
+        scaler.scale(loss).backward()
+        scaler.step(optimizer)
+        scaler.update()
         train_loss += loss.item()
     return train_loss / len(dataloader)
 
 
-def validate_epoch(model, dataloader, criterion, device):
+def validate(model, metric, criterion, dataloader, device, conf):
+    """执行验证"""
     model.eval()
     val_loss = 0
     with torch.no_grad():
-        for data, labels in tqdm(dataloader, desc="Validation", ascii=True, total=len(dataloader)):
-            data, labels = data.to(device), labels.to(device)
-            embeddings = model(data).to(device)
+        for data, labels in tqdm(dataloader, desc="Validating", ascii=True, total=len(dataloader)):
+            data = data.to(device)
+            labels = labels.to(device)
+            embeddings = model(data)
             if not conf['training']['metric'] == 'softmax':
                 thetas = metric(embeddings, labels)
             else:
@@ -167,143 +108,163 @@ def validate_epoch(model, dataloader, criterion, device):
     return val_loss / len(dataloader)
 
 
-def save_model(model, path, distributed):
-    if distributed and torch.cuda.device_count() > 1:
-        if dist.get_rank() == 0:
-            torch.save(model.module.state_dict(), path)
+def save_model(model, path, is_parallel):
+    """保存模型权重"""
+    if is_parallel:
+        torch.save(model.module.state_dict(), path)
     else:
         torch.save(model.state_dict(), path)
 
 
-def write_log(log_info, log_dir):
-    with open(log_dir, 'a') as f:
+def log_training_info(log_path, log_info):
+    """记录训练信息到日志文件"""
+    with open(log_path, 'a') as f:
         f.write(log_info + '\n')
 
 
-def plot_losses(epochs, train_losses, val_losses, save_path):
-    plt.plot(epochs, train_losses, color='blue', label='Train Loss')
-    plt.plot(epochs, val_losses, color='red', label='Validation Loss')
-    plt.xlabel('Epochs')
-    plt.ylabel('Loss')
-    plt.legend()
-    plt.savefig(save_path)
-    plt.close()
-
-
-# 模型恢复封装
-def restore_model(model, device):
-    if conf['training']['restore']:
-        print('load pretrain model: {}'.format(conf['training']['restore_model']))
-        model.load_state_dict(torch.load(conf['training']['restore_model'], map_location=device))
-    return model
-
-
-# 日志和学习率记录封装
-def log_and_print(e, train_loss_avg, val_loss_avg, current_lr, log_dir):
-    if conf['base']['distributed'] and dist.get_rank() != 0:
-        return
-
-    log_info = ("[{:%Y-%m-%d %H:%M:%S}] Epoch {}/{}, train_loss: {}, val_loss: {} lr:{}"
-                .format(datetime.now(),
-                        e,
-                        conf['training']['epochs'],
-                        train_loss_avg,
-                        val_loss_avg,
-                        current_lr))
-    print(log_info)
-    write_log(log_info, log_dir)
-    print("第%d个epoch的学习率:%f" % (e, current_lr))
-
-
-# 模型评估与保存封装
-def evaluate_and_save(val_loss_avg, best_loss, model, checkpoints, distributed):
-    if val_loss_avg < best_loss:
-        best_path = osp.join(checkpoints, 'best.pth')
-        save_model(model, best_path, distributed)
-        best_loss = val_loss_avg
-    return best_loss
-
-
-def run_training(rank, world_size, conf):
-    """在指定 rank 上运行训练 loop。
-
-    Args:
-        rank: 当前进程的索引。
-        world_size: 进程总数。
-        conf: 配置字典。
-    """
-    dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)
-    device = torch.device('cuda', rank)
-
-    # 数据加载器和模型等需要重新初始化以确保每个进程独立工作
-    train_dataloader, val_dataloader, class_num = load_datasets()
-    tr_tools = trainer_tools(conf)
-    backbone_mapping = tr_tools.get_backbone()
-    metric_mapping = tr_tools.get_metric(class_num)
-
-    # 模型初始化
-    if conf['models']['backbone'] in backbone_mapping:
-        model = backbone_mapping[conf['models']['backbone']]().to(device)
-    else:
-        raise ValueError('不支持该模型: {}'.format({conf['models']['backbone']}))
-
-    if conf['training']['metric'] in metric_mapping:
-        metric = metric_mapping[conf['training']['metric']]().to(device)
-    else:
-        raise ValueError('不支持的metric类型: {}'.format(conf['training']['metric']))
-
-    model = DDP(model, device_ids=[rank])
-    metric = DDP(metric, device_ids=[rank])
-
-    # 初始化组件
-    criterion, optimizer, scheduler = initialize_components()
-
-    # 恢复模型(如果需要)
-    model = restore_model(model, device)
-
+def initialize_training_components():
+    """初始化所有训练所需组件"""
+    # 加载配置
+    conf = load_configuration()
+
+    # 初始化分布式训练
+    distributed = conf['base']['distributed']
+    if distributed:
+        dist.init_process_group(backend='nccl')
+        local_rank = int(os.environ["LOCAL_RANK"])
+        torch.cuda.set_device(local_rank)
+        device = torch.device('cuda', local_rank)
+    else:
+        device = conf['base']['device']
+
+    # 数据加载
+    train_dataloader, class_num = load_data(training=True, cfg=conf)
+    val_dataloader, _ = load_data(training=False, cfg=conf)
+
+    # 如果使用分布式,需要为每个进程创建单独的数据加载器
+    if distributed:
+        train_sampler = DistributedSampler(train_dataloader.dataset, shuffle=True)
+        val_sampler = DistributedSampler(val_dataloader.dataset, shuffle=False)
+
+        # 重新创建适合分布式训练的数据加载器
+        train_dataloader = torch.utils.data.DataLoader(
+            train_dataloader.dataset,
+            batch_size=train_dataloader.batch_size,
+            sampler=train_sampler,
+            num_workers=train_dataloader.num_workers,
+            pin_memory=train_dataloader.pin_memory,
+            drop_last=train_dataloader.drop_last
+        )
+
+        val_dataloader = torch.utils.data.DataLoader(
+            val_dataloader.dataset,
+            batch_size=val_dataloader.batch_size,
+            sampler=val_sampler,
+            num_workers=val_dataloader.num_workers,
+            pin_memory=val_dataloader.pin_memory,
+            drop_last=val_dataloader.drop_last
+        )
+
+    # 初始化模型和度量
+    model, metric = initialize_model_and_metric(conf, class_num)
+    model = model.to(device)
+    metric = metric.to(device)
+
+    if distributed:
+        model = DDP(model, device_ids=[local_rank], output_device=local_rank)
+        metric = DDP(metric, device_ids=[local_rank], output_device=local_rank)
+
+    # 设置损失函数、优化器和调度器
+    criterion = setup_loss_function(conf)
+    optimizer, scheduler = setup_optimizer_and_scheduler(conf, model, metric)
+
+    # 检查点目录
+    checkpoints = conf['training']['checkpoints']
+    os.makedirs(checkpoints, exist_ok=True)
+
+    # GradScaler for mixed precision
+    scaler = torch.cuda.amp.GradScaler()
+
+    return {
+        'conf': conf,
+        'train_dataloader': train_dataloader,
+        'val_dataloader': val_dataloader,
+        'model': model,
+        'metric': metric,
+        'criterion': criterion,
+        'optimizer': optimizer,
+        'scheduler': scheduler,
+        'checkpoints': checkpoints,
+        'scaler': scaler,
+        'device': device,
+        'distributed': distributed
+    }
+
+
+def run_training_loop(components):
+    """运行完整的训练循环"""
+    # 解包组件
+    conf = components['conf']
+    train_dataloader = components['train_dataloader']
+    val_dataloader = components['val_dataloader']
+    model = components['model']
+    metric = components['metric']
+    criterion = components['criterion']
+    optimizer = components['optimizer']
+    scheduler = components['scheduler']
+    checkpoints = components['checkpoints']
+    scaler = components['scaler']
+    device = components['device']
+
+    # 训练状态
     train_losses = []
     val_losses = []
     epochs = []
-    temp_loss = 1000
-    log_dir = osp.join(conf['logging']['logging_dir'])
+    temp_loss = 100
+
+    if conf['training']['restore']:
+        print('load pretrain model: {}'.format(conf['training']['restore_model']))
+        model.load_state_dict(torch.load(conf['training']['restore_model'], map_location=device))
 
+    # 训练循环
     for e in range(conf['training']['epochs']):
-        train_loss_avg = train_epoch(model, train_dataloader, optimizer, criterion, device)
+        train_loss_avg = train_one_epoch(model, metric, criterion, optimizer, train_dataloader, device, scaler, conf)
         train_losses.append(train_loss_avg)
+        epochs.append(e)
 
-        val_loss_avg = validate_epoch(model, val_dataloader, criterion, device)
+        val_loss_avg = validate(model, metric, criterion, val_dataloader, device, conf)
         val_losses.append(val_loss_avg)
 
-        temp_loss = evaluate_and_save(val_loss_avg, temp_loss, model, conf['training']['checkpoints'], True)
+        if val_loss_avg < temp_loss:
+            save_model(model, osp.join(checkpoints, 'best.pth'), isinstance(model, nn.DataParallel))
+            temp_loss = val_loss_avg
 
         scheduler.step()
         current_lr = optimizer.param_groups[0]['lr']
-        log_and_print(e, train_loss_avg, val_loss_avg, current_lr, log_dir)
-        epochs.append(e)
+        log_info = ("[{:%Y-%m-%d %H:%M:%S}] Epoch {}/{}, train_loss: {}, val_loss: {} lr:{}"
+                    .format(datetime.now(),
+                            e,
+                            conf['training']['epochs'],
+                            train_loss_avg,
+                            val_loss_avg,
+                            current_lr))
+        print(log_info)
+        log_training_info(osp.join(conf['logging']['logging_dir']), log_info)
+        print("第%d个epoch的学习率:%f" % (e, current_lr))
 
-    last_path = osp.join(conf['training']['checkpoints'], 'last.pth')
-    save_model(model, last_path, True)
-    plot_losses(epochs, train_losses, val_losses, 'loss/mobilenetv3Large_2250_0316.png')
-
-    dist.destroy_process_group()
-
-
-def main():
-    world_size = torch.cuda.device_count()
-    mp.spawn(
-        run_training,
-        args=(world_size, conf),
-        nprocs=world_size,
-        join=True,
-    )
+    # 保存最终模型
+    save_model(model, osp.join(checkpoints, 'last.pth'), isinstance(model, nn.DataParallel))
+
+    # 绘制损失曲线
+    plt.plot(epochs, train_losses, color='blue', label='Train Loss')
+    plt.plot(epochs, val_losses, color='red', label='Validation Loss')
+    plt.legend()
+    plt.savefig('loss/mobilenetv3Large_2250_0316.png')
 
 
 if __name__ == '__main__':
-    print('backbone>{} '.format(conf['models']['backbone']),
-          'metric>{} '.format(conf['training']['metric']),
-          'checkpoints>{} '.format(conf['training']['checkpoints']),
-          )
-    if conf['base']['distributed']:
-        main()
-    else:
-        run_training(rank=0, world_size=1, conf=conf)
+    # 初始化训练组件
+    components = initialize_training_components()
+
+    # 运行训练循环
+    run_training_loop(components)
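Note: in the refactored distributed path above, initialize_training_components() reads LOCAL_RANK from the environment and calls dist.init_process_group(backend='nccl') with no explicit rank or world size, so it expects an external launcher to set those variables. Assuming a single node with N GPUs, such a script would typically be started with PyTorch's launcher, for example: torchrun --nproc_per_node=N train_compare.py (torchrun sets LOCAL_RANK, RANK and WORLD_SIZE for each worker process). The exact launch command is not part of this commit. The remaining hunks below apply to the DataParallel-based variant of the training script.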
@@ -3,7 +3,6 @@ import os.path as osp
 import torch
 import torch.nn as nn
-import torch.optim as optim
 from tqdm import tqdm
 
 from model.loss import FocalLoss
@@ -12,116 +11,196 @@ import matplotlib.pyplot as plt
 from configs import trainer_tools
 import yaml
 from datetime import datetime
-with open('configs/scatter.yml', 'r') as f:
-    conf = yaml.load(f, Loader=yaml.FullLoader)
-
-# Data Setup
-train_dataloader, class_num = load_data(training=True, cfg=conf)
-val_dataloader, _ = load_data(training=False, cfg=conf)
-
-tr_tools = trainer_tools(conf)
-backbone_mapping = tr_tools.get_backbone()
-metric_mapping = tr_tools.get_metric(class_num)
-
-if conf['models']['backbone'] in backbone_mapping:
-    model = backbone_mapping[conf['models']['backbone']]().to(conf['base']['device'])
-else:
-    raise ValueError('不支持该模型: {}'.format({conf['models']['backbone']}))
-
-if conf['training']['metric'] in metric_mapping:
-    metric = metric_mapping[conf['training']['metric']]()
-else:
-    raise ValueError('不支持的metric类型: {}'.format(conf['training']['metric']))
-
-if torch.cuda.device_count() > 1 and conf['base']['distributed']:
-    print("Let's use", torch.cuda.device_count(), "GPUs!")
-    model = nn.DataParallel(model)
-    metric = nn.DataParallel(metric)
-
-# Training Setup
-if conf['training']['loss'] == 'focal_loss':
-    criterion = FocalLoss(gamma=2)
-else:
-    criterion = nn.CrossEntropyLoss()
-
-optimizer_mapping = tr_tools.get_optimizer(model, metric)
-if conf['training']['optimizer'] in optimizer_mapping:
-    optimizer = optimizer_mapping[conf['training']['optimizer']]()
-    scheduler_mapping = tr_tools.get_scheduler(optimizer)
-    scheduler = scheduler_mapping[conf['training']['scheduler']]()
-    print('使用{}优化器 使用{}调度器'.format(conf['training']['optimizer'],
-                                     conf['training']['scheduler']))
-else:
-    raise ValueError('不支持的优化器类型: {}'.format(conf['training']['optimizer']))
-
-# Checkpoints Setup
-checkpoints = conf['training']['checkpoints']
-os.makedirs(checkpoints, exist_ok=True)
-
-if __name__ == '__main__':
-    print('backbone>{} '.format(conf['models']['backbone']),
-          'metric>{} '.format(conf['training']['metric']),
-          'checkpoints>{} '.format(conf['training']['checkpoints']),
-          )
+
+
+def load_configuration(config_path='configs/scatter.yml'):
+    """加载配置文件"""
+    with open(config_path, 'r') as f:
+        return yaml.load(f, Loader=yaml.FullLoader)
+
+
+def initialize_model_and_metric(conf, class_num):
+    """初始化模型和度量方法"""
+    tr_tools = trainer_tools(conf)
+    backbone_mapping = tr_tools.get_backbone()
+    metric_mapping = tr_tools.get_metric(class_num)
+
+    if conf['models']['backbone'] in backbone_mapping:
+        model = backbone_mapping[conf['models']['backbone']]()
+    else:
+        raise ValueError('不支持该模型: {}'.format({conf['models']['backbone']}))
+
+    if conf['training']['metric'] in metric_mapping:
+        metric = metric_mapping[conf['training']['metric']]()
+    else:
+        raise ValueError('不支持的metric类型: {}'.format(conf['training']['metric']))
+
+    return model, metric
+
+
+def setup_optimizer_and_scheduler(conf, model, metric):
+    """设置优化器和学习率调度器"""
+    tr_tools = trainer_tools(conf)
+    optimizer_mapping = tr_tools.get_optimizer(model, metric)
+
+    if conf['training']['optimizer'] in optimizer_mapping:
+        optimizer = optimizer_mapping[conf['training']['optimizer']]()
+        scheduler_mapping = tr_tools.get_scheduler(optimizer)
+        scheduler = scheduler_mapping[conf['training']['scheduler']]()
+        print('使用{}优化器 使用{}调度器'.format(conf['training']['optimizer'],
+                                         conf['training']['scheduler']))
+        return optimizer, scheduler
+    else:
+        raise ValueError('不支持的优化器类型: {}'.format(conf['training']['optimizer']))
+
+
+def setup_loss_function(conf):
+    """配置损失函数"""
+    if conf['training']['loss'] == 'focal_loss':
+        return FocalLoss(gamma=2)
+    else:
+        return nn.CrossEntropyLoss()
+
+
+def train_one_epoch(model, metric, criterion, optimizer, dataloader, device, scaler, conf):
+    """执行单个训练周期"""
+    model.train()
+    train_loss = 0
+    for data, labels in tqdm(dataloader, desc="Training", ascii=True, total=len(dataloader)):
+        data = data.to(device)
+        labels = labels.to(device)
+
+        with torch.cuda.amp.autocast():
+            embeddings = model(data)
+            if not conf['training']['metric'] == 'softmax':
+                thetas = metric(embeddings, labels)
+            else:
+                thetas = metric(embeddings)
+            loss = criterion(thetas, labels)
+
+        optimizer.zero_grad()
+        scaler.scale(loss).backward()
+        scaler.step(optimizer)
+        scaler.update()
+        train_loss += loss.item()
+    return train_loss / len(dataloader)
+
+
+def validate(model, metric, criterion, dataloader, device, conf):
+    """执行验证"""
+    model.eval()
+    val_loss = 0
+    with torch.no_grad():
+        for data, labels in tqdm(dataloader, desc="Validating", ascii=True, total=len(dataloader)):
+            data = data.to(device)
+            labels = labels.to(device)
+            embeddings = model(data)
+            if not conf['training']['metric'] == 'softmax':
+                thetas = metric(embeddings, labels)
+            else:
+                thetas = metric(embeddings)
+            loss = criterion(thetas, labels)
+            val_loss += loss.item()
+    return val_loss / len(dataloader)
+
+
+def save_model(model, path, is_parallel):
+    """保存模型权重"""
+    if is_parallel:
+        torch.save(model.module.state_dict(), path)
+    else:
+        torch.save(model.state_dict(), path)
+
+
+def log_training_info(log_path, log_info):
+    """记录训练信息到日志文件"""
+    with open(log_path, 'a') as f:
+        f.write(log_info + '\n')
+
+
+def initialize_training_components():
+    """初始化所有训练所需组件"""
+    # 加载配置
+    conf = load_configuration()
+
+    # 数据加载
+    train_dataloader, class_num = load_data(training=True, cfg=conf)
+    val_dataloader, _ = load_data(training=False, cfg=conf)
+
+    # 初始化模型和度量
+    model, metric = initialize_model_and_metric(conf, class_num)
+    device = conf['base']['device']
+    model = model.to(device)
+    metric = metric.to(device)
+
+    if torch.cuda.device_count() > 1 and conf['base']['distributed']:
+        print("Let's use", torch.cuda.device_count(), "GPUs!")
+        model = nn.DataParallel(model)
+        metric = nn.DataParallel(metric)
+
+    # 设置损失函数、优化器和调度器
+    criterion = setup_loss_function(conf)
+    optimizer, scheduler = setup_optimizer_and_scheduler(conf, model, metric)
+
+    # 检查点目录
+    checkpoints = conf['training']['checkpoints']
+    os.makedirs(checkpoints, exist_ok=True)
+
+    # GradScaler for mixed precision
+    scaler = torch.cuda.amp.GradScaler()
+
+    return {
+        'conf': conf,
+        'train_dataloader': train_dataloader,
+        'val_dataloader': val_dataloader,
+        'model': model,
+        'metric': metric,
+        'criterion': criterion,
+        'optimizer': optimizer,
+        'scheduler': scheduler,
+        'checkpoints': checkpoints,
+        'scaler': scaler,
+        'device': device
+    }
+
+
+def run_training_loop(components):
+    """运行完整的训练循环"""
+    # 解包组件
+    conf = components['conf']
+    train_dataloader = components['train_dataloader']
+    val_dataloader = components['val_dataloader']
+    model = components['model']
+    metric = components['metric']
+    criterion = components['criterion']
+    optimizer = components['optimizer']
+    scheduler = components['scheduler']
+    checkpoints = components['checkpoints']
+    scaler = components['scaler']
+    device = components['device']
+
+    # 训练状态
     train_losses = []
     val_losses = []
     epochs = []
     temp_loss = 100
 
     if conf['training']['restore']:
         print('load pretrain model: {}'.format(conf['training']['restore_model']))
-        model.load_state_dict(torch.load(conf['training']['restore_model'],
-                                         map_location=conf['base']['device']))
+        model.load_state_dict(torch.load(conf['training']['restore_model'], map_location=device))
 
+    # 训练循环
     for e in range(conf['training']['epochs']):
-        train_loss = 0
-        model.train()
-
-        for train_data, train_labels in tqdm(train_dataloader,
-                                             desc="Epoch {}/{}"
-                                             .format(e, conf['training']['epochs']),
-                                             ascii=True,
-                                             total=len(train_dataloader)):
-            train_data = train_data.to(conf['base']['device'])
-            train_labels = train_labels.to(conf['base']['device'])
-
-            train_embeddings = model(train_data).to(conf['base']['device'])  # [256,512]
-            # pdb.set_trace()
-
-            if not conf['training']['metric'] == 'softmax':
-                thetas = metric(train_embeddings, train_labels)  # [256,357]
-            else:
-                thetas = metric(train_embeddings)
-            tloss = criterion(thetas, train_labels)
-            optimizer.zero_grad()
-            tloss.backward()
-            optimizer.step()
-            train_loss += tloss.item()
-        train_lossAvg = train_loss / len(train_dataloader)
-        train_losses.append(train_lossAvg)
+        train_loss_avg = train_one_epoch(model, metric, criterion, optimizer, train_dataloader, device, scaler, conf)
+        train_losses.append(train_loss_avg)
         epochs.append(e)
-        val_loss = 0
-        model.eval()
-        with torch.no_grad():
-            for val_data, val_labels in tqdm(val_dataloader, desc="val",
-                                             ascii=True, total=len(val_dataloader)):
-                val_data = val_data.to(conf['base']['device'])
-                val_labels = val_labels.to(conf['base']['device'])
-                val_embeddings = model(val_data).to(conf['base']['device'])
-                if not conf['training']['metric'] == 'softmax':
-                    thetas = metric(val_embeddings, val_labels)
-                else:
-                    thetas = metric(val_embeddings)
-                vloss = criterion(thetas, val_labels)
-                val_loss += vloss.item()
-            val_lossAvg = val_loss / len(val_dataloader)
-            val_losses.append(val_lossAvg)
-            if val_lossAvg < temp_loss:
-                if torch.cuda.device_count() > 1:
-                    torch.save(model.state_dict(), osp.join(checkpoints, 'best.pth'))
-                else:
-                    torch.save(model.state_dict(), osp.join(checkpoints, 'best.pth'))
-                temp_loss = val_lossAvg
+
+        val_loss_avg = validate(model, metric, criterion, val_dataloader, device, conf)
+        val_losses.append(val_loss_avg)
+
+        if val_loss_avg < temp_loss:
+            save_model(model, osp.join(checkpoints, 'best.pth'), isinstance(model, nn.DataParallel))
+            temp_loss = val_loss_avg
 
         scheduler.step()
         current_lr = optimizer.param_groups[0]['lr']
@@ -129,19 +208,26 @@ if __name__ == '__main__':
                     .format(datetime.now(),
                             e,
                             conf['training']['epochs'],
-                            train_lossAvg,
-                            val_lossAvg,
+                            train_loss_avg,
+                            val_loss_avg,
                             current_lr))
         print(log_info)
-        # 写入日志文件
-        with open(osp.join(conf['logging']['logging_dir']), 'a') as f:
-            f.write(log_info + '\n')
+        log_training_info(osp.join(conf['logging']['logging_dir']), log_info)
         print("第%d个epoch的学习率:%f" % (e, current_lr))
 
-    if torch.cuda.device_count() > 1 and conf['base']['distributed']:
-        torch.save(model.module.state_dict(), osp.join(checkpoints, 'last.pth'))
-    else:
-        torch.save(model.state_dict(), osp.join(checkpoints, 'last.pth'))
+    # 保存最终模型
+    save_model(model, osp.join(checkpoints, 'last.pth'), isinstance(model, nn.DataParallel))
 
-    plt.plot(epochs, train_losses, color='blue')
-    plt.plot(epochs, val_losses, color='red')
-    # plt.savefig('lossMobilenetv3.png')
+    # 绘制损失曲线
+    plt.plot(epochs, train_losses, color='blue', label='Train Loss')
+    plt.plot(epochs, val_losses, color='red', label='Validation Loss')
+    plt.legend()
     plt.savefig('loss/mobilenetv3Large_2250_0316.png')
 
 
+if __name__ == '__main__':
+    # 初始化训练组件
+    components = initialize_training_components()
+
+    # 运行训练循环
+    run_training_loop(components)