Modify the DataLoader to improve training efficiency

Author: lee
Date: 2025-08-07 10:52:42 +08:00
parent 3392d76e38
commit ebba07d1ca
3 changed files with 98 additions and 48 deletions


@@ -10,7 +10,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.data.distributed import DistributedSampler
 from model.loss import FocalLoss
-from tools.dataset import load_data
+from tools.dataset import load_data, MultiEpochsDataLoader
 import matplotlib.pyplot as plt
 from configs import trainer_tools
 import yaml
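
Note (not part of the diff): the MultiEpochsDataLoader imported here is defined in tools/dataset.py, one of the other changed files, which is not shown on this page. Its exact implementation may differ, but a minimal sketch of the usual pattern (a DataLoader subclass that keeps its worker processes alive across epochs by feeding them a batch sampler that never terminates) looks like this:

import torch

class MultiEpochsDataLoader(torch.utils.data.DataLoader):
    """DataLoader whose workers are created once and reused for every epoch."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._DataLoader__initialized = False
        # Swap in a batch sampler that repeats forever, so the single
        # persistent iterator created below is never exhausted.
        self.batch_sampler = _RepeatSampler(self.batch_sampler)
        self._DataLoader__initialized = True
        self.iterator = super().__iter__()

    def __len__(self):
        # Length of the wrapped (original) batch sampler, i.e. batches per epoch.
        return len(self.batch_sampler.sampler)

    def __iter__(self):
        # Hand out exactly one epoch's worth of batches per call.
        for _ in range(len(self)):
            yield next(self.iterator)

class _RepeatSampler:
    """Wraps a batch sampler and re-iterates it endlessly."""

    def __init__(self, sampler):
        self.sampler = sampler

    def __iter__(self):
        while True:
            yield from iter(self.sampler)

The efficiency gain targeted by this commit comes from the num_workers processes and their prefetch buffers surviving epoch boundaries instead of being torn down and re-spawned at the start of every epoch.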
@@ -52,7 +52,7 @@ def setup_optimizer_and_scheduler(conf, model, metric):
         scheduler_mapping = tr_tools.get_scheduler(optimizer)
         scheduler = scheduler_mapping[conf['training']['scheduler']]()
         print('使用{}优化器 使用{}调度器'.format(conf['training']['optimizer'],
-                                       conf['training']['scheduler']))
+                                        conf['training']['scheduler']))
         return optimizer, scheduler
     else:
         raise ValueError('不支持的优化器类型: {}'.format(conf['training']['optimizer']))
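
Note (not part of the diff): this hunk and the ones below read a handful of keys from the YAML configuration loaded by load_configuration(). The config file itself is not shown; once parsed it would be a dict roughly shaped as follows, where only the keys are taken from the code and every value is a placeholder:

# Placeholder values; the keys mirror the lookups in this file.
conf = {
    'base': {'device': 'cuda', 'pin_memory': True, 'distributed': False},
    'data': {'train_batch_size': 128, 'val_batch_size': 256, 'num_workers': 8},
    'training': {'optimizer': 'sgd', 'scheduler': 'cosine', 'checkpoints': 'checkpoints/'},
}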
@@ -146,9 +146,21 @@ def initialize_training_components(distributed=False):
     # For non-distributed training, create all components directly
     if not distributed:
         # Load data
-        train_dataloader, class_num = load_data(training=True, cfg=conf)
-        val_dataloader, _ = load_data(training=False, cfg=conf)
+        train_dataloader, class_num = load_data(training=True, cfg=conf, return_dataset=True)
+        val_dataloader, _ = load_data(training=False, cfg=conf, return_dataset=True)
+        train_dataloader = MultiEpochsDataLoader(train_dataloader,
+                                                 batch_size=conf['data']['train_batch_size'],
+                                                 shuffle=True,
+                                                 num_workers=conf['data']['num_workers'],
+                                                 pin_memory=conf['base']['pin_memory'],
+                                                 drop_last=True)
+        val_dataloader = MultiEpochsDataLoader(val_dataloader,
+                                               batch_size=conf['data']['val_batch_size'],
+                                               shuffle=False,
+                                               num_workers=conf['data']['num_workers'],
+                                               pin_memory=conf['base']['pin_memory'],
+                                               drop_last=False)
         # Initialize model and metric
         model, metric = initialize_model_and_metric(conf, class_num)
         device = conf['base']['device']
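
Note (not part of the diff): return_dataset=True is taken here to mean that load_data now returns the raw Dataset (plus the class count) instead of a ready-made DataLoader, which is why the result is wrapped manually right after. For comparison, stock PyTorch (>= 1.7) offers persistent_workers=True on torch.utils.data.DataLoader, which also keeps worker processes alive across epochs; a sketch of that alternative, using an assumed train_dataset variable for the object returned by load_data:

from torch.utils.data import DataLoader

train_loader = DataLoader(
    train_dataset,                                        # dataset from load_data(..., return_dataset=True)
    batch_size=conf['data']['train_batch_size'],
    shuffle=True,
    num_workers=conf['data']['num_workers'],
    pin_memory=conf['base']['pin_memory'],
    drop_last=True,
    persistent_workers=conf['data']['num_workers'] > 0,  # only valid when num_workers > 0
)

MultiEpochsDataLoader goes one step further by also reusing the prefetch iterator across epochs, whereas persistent_workers only avoids re-spawning the workers; the latter is the simpler drop-in if worker startup cost is the main concern.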
@@ -248,10 +260,10 @@ def main():
     """Main entry point."""
     # Load the configuration
     conf = load_configuration()
     # Check whether distributed training is enabled
     distributed = conf['base']['distributed']
     if distributed:
         # Distributed training launches multiple processes via mp.spawn
         world_size = torch.cuda.device_count()
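
Note (not part of the diff): the mp.spawn call that consumes world_size sits outside the captured lines. Given the run_training(rank, world_size, conf) signature in the next hunk, the launch typically looks like the sketch below; the exact call is an assumption, not the commit's literal code:

import torch
import torch.multiprocessing as mp

if distributed:
    world_size = torch.cuda.device_count()
    # mp.spawn starts one process per GPU and passes the process rank
    # as the first positional argument of run_training.
    mp.spawn(run_training, args=(world_size, conf), nprocs=world_size, join=True)
else:
    components = initialize_training_components(distributed=False)
    run_training_loop(components)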
@@ -274,56 +286,56 @@ def run_training(rank, world_size, conf):
     os.environ['WORLD_SIZE'] = str(world_size)
     os.environ['MASTER_ADDR'] = 'localhost'
     os.environ['MASTER_PORT'] = '12355'
     dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)
     torch.cuda.set_device(rank)
     device = torch.device('cuda', rank)
-    # Create the data loaders, model and other components (distributed case)
-    train_dataloader, class_num = load_data(training=True, cfg=conf)
-    val_dataloader, _ = load_data(training=False, cfg=conf)
+    # Get the datasets instead of DataLoaders
+    train_dataset, class_num = load_data(training=True, cfg=conf, return_dataset=True)
+    val_dataset, _ = load_data(training=False, cfg=conf, return_dataset=True)
     # Initialize model and metric
     model, metric = initialize_model_and_metric(conf, class_num)
     model = model.to(device)
     metric = metric.to(device)
     # Wrap as DistributedDataParallel models
     model = DDP(model, device_ids=[rank], output_device=rank)
     metric = DDP(metric, device_ids=[rank], output_device=rank)
     # Set up the loss function, optimizer and scheduler
     criterion = setup_loss_function(conf)
     optimizer, scheduler = setup_optimizer_and_scheduler(conf, model, metric)
     # Checkpoint directory
     checkpoints = conf['training']['checkpoints']
     os.makedirs(checkpoints, exist_ok=True)
     # GradScaler for mixed precision
     scaler = torch.cuda.amp.GradScaler()
-    # Create distributed data loading
-    train_sampler = DistributedSampler(train_dataloader.dataset, shuffle=True)
-    val_sampler = DistributedSampler(val_dataloader.dataset, shuffle=False)
-    # Recreate data loaders suited for distributed training
-    train_dataloader = torch.utils.data.DataLoader(
-        train_dataloader.dataset,
-        batch_size=train_dataloader.batch_size,
+    # Create distributed samplers
+    train_sampler = DistributedSampler(train_dataset, shuffle=True)
+    val_sampler = DistributedSampler(val_dataset, shuffle=False)
+    # Create distributed data loaders with MultiEpochsDataLoader
+    train_dataloader = MultiEpochsDataLoader(
+        train_dataset,
+        batch_size=conf['data']['train_batch_size'],
         sampler=train_sampler,
-        num_workers=train_dataloader.num_workers,
-        pin_memory=train_dataloader.pin_memory,
-        drop_last=train_dataloader.drop_last
+        num_workers=conf['data']['num_workers'],
+        pin_memory=conf['base']['pin_memory'],
+        drop_last=True
     )
-    val_dataloader = torch.utils.data.DataLoader(
-        val_dataloader.dataset,
-        batch_size=val_dataloader.batch_size,
+    val_dataloader = MultiEpochsDataLoader(
+        val_dataset,
+        batch_size=conf['data']['val_batch_size'],
         sampler=val_sampler,
-        num_workers=val_dataloader.num_workers,
-        pin_memory=val_dataloader.pin_memory,
-        drop_last=val_dataloader.drop_last
+        num_workers=conf['data']['num_workers'],
+        pin_memory=conf['base']['pin_memory'],
+        drop_last=False
     )
     # Build the components dictionary
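
Note (not part of the diff): a DistributedSampler only reshuffles its per-rank shards when the training loop calls set_epoch before each pass; the loop itself lives in run_training_loop and is not captured here. The relevant call would sit roughly like this, with num_epochs standing in for whatever epoch count the config provides:

for epoch in range(num_epochs):
    train_sampler.set_epoch(epoch)   # give each epoch a different shuffle on every rank
    for batch in train_dataloader:
        ...                          # forward/backward as usual

One caveat worth checking with this commit: because MultiEpochsDataLoader keeps a single persistent iterator whose workers prefetch ahead, a set_epoch call may only take effect after the already-prefetched batches have been consumed.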
@@ -341,7 +353,7 @@ def run_training(rank, world_size, conf):
         'device': device,
         'distributed': True  # because this runs inside mp.spawn
     }
     # Run the training loop
     run_training_loop(components)