Modify the DataLoader to improve training efficiency
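Adds a MultiEpochsDataLoader that keeps DataLoader worker processes alive across epochs instead of restarting them at every epoch boundary, extends load_data with a return_dataset option so callers can wrap the raw dataset themselves, and wires the new loader into both the single-process and the distributed (DDP) training paths. The config also switches from the resnet18 run to a resnet50 run, with the checkpoint and log directories updated to match.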
@@ -15,7 +15,7 @@ base:
 
 # model configuration
 models:
-  backbone: 'resnet18'
+  backbone: 'resnet50'
   channel_ratio: 1.0
 
 # training parameters
@@ -31,7 +31,7 @@ training:
   weight_decay: 0.0005  # weight decay
   scheduler: "step"  # learning-rate scheduler (options: cosine/cosine_warm/step/None)
   num_workers: 32  # number of data-loading workers
-  checkpoints: "./checkpoints/resnet18_electornic_20250806/"  # checkpoint save directory
+  checkpoints: "./checkpoints/resnet50_electornic_20250807/"  # checkpoint save directory
   restore: false
   restore_model: "./checkpoints/resnet18_20250717_scale=0.75_nosub/best.pth"  # model restore path
   cosine_t_0: 10  # initial cycle length
@@ -62,7 +62,7 @@ transform:
 
 # logging and monitoring
 logging:
-  logging_dir: "./logs/resnet18_scale=0.75_nosub_log"  # log directory
+  logging_dir: "./logs/resnet50_electornic_log"  # log directory
   tensorboard: true  # enable TensorBoard
   checkpoint_interval: 30  # checkpoint save interval (epochs)
 
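The remaining hunks touch tools/dataset.py and the training script (the commit view does not show the script's file name). For orientation, a minimal sketch of how the training code consumes this YAML file via PyYAML; the path configs/config.yaml is a placeholder, since the real path is not shown:

import yaml

# hypothetical path; the commit view does not name the config file
with open('configs/config.yaml', 'r', encoding='utf-8') as f:
    conf = yaml.safe_load(f)

print(conf['models']['backbone'])       # 'resnet50' after this commit
print(conf['training']['checkpoints'])  # './checkpoints/resnet50_electornic_20250807/'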
@@ -5,12 +5,14 @@ import torchvision.transforms as T
 # from config import config as conf
 import torch
+
+
 def pad_to_square(img):
     w, h = img.size
     max_wh = max(w, h)
     padding = [(max_wh - w) // 2, (max_wh - h) // 2, (max_wh - w) // 2, (max_wh - h) // 2]  # (left, top, right, bottom)
     return F.pad(img, padding, fill=0, padding_mode='constant')
 
 
 def get_transform(cfg):
     train_transform = T.Compose([
         T.Lambda(pad_to_square),  # pad to square
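As a quick illustration of the padding helper above, a self-contained sketch (the function body is copied from the diff):

from PIL import Image
import torchvision.transforms.functional as F

def pad_to_square(img):
    w, h = img.size
    max_wh = max(w, h)
    padding = [(max_wh - w) // 2, (max_wh - h) // 2, (max_wh - w) // 2, (max_wh - h) // 2]  # (left, top, right, bottom)
    return F.pad(img, padding, fill=0, padding_mode='constant')

img = Image.new('RGB', (200, 100))  # wide dummy image
print(pad_to_square(img).size)      # (200, 200): black bars added top and bottom

Note that when the width/height difference is odd, the integer division pads one pixel short, so the output is one pixel off square in that case.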
@@ -32,7 +34,8 @@ def get_transform(cfg):
     ])
     return train_transform, test_transform
 
-def load_data(training=True, cfg=None):
+
+def load_data(training=True, cfg=None, return_dataset=False):
     train_transform, test_transform = get_transform(cfg)
     if training:
         dataroot = cfg['data']['data_train_dir']
@@ -47,14 +50,49 @@ def load_data(training=True, cfg=None):
 
     data = ImageFolder(dataroot, transform=transform)
     class_num = len(data.classes)
-    loader = DataLoader(data,
-                        batch_size=batch_size,
-                        shuffle=True,
-                        pin_memory=cfg['base']['pin_memory'],
-                        num_workers=cfg['data']['num_workers'],
-                        drop_last=True)
-    return loader, class_num
+    if return_dataset:
+        return data, class_num
+    else:
+        loader = DataLoader(data,
+                            batch_size=batch_size,
+                            shuffle=True if training else False,
+                            pin_memory=cfg['base']['pin_memory'],
+                            num_workers=cfg['data']['num_workers'],
+                            drop_last=True)
+        return loader, class_num
 
+
+class MultiEpochsDataLoader(torch.utils.data.DataLoader):
+    """
+    MultiEpochsDataLoader class
+    Improves data-loading efficiency by reusing worker processes instead of restarting them every epoch.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._DataLoader__initialized = False
+        self.batch_sampler = _RepeatSampler(self.batch_sampler)
+        self._DataLoader__initialized = True
+        self.iterator = super().__iter__()
+
+    def __len__(self):
+        return len(self.batch_sampler.sampler)
+
+    def __iter__(self):
+        for i in range(len(self)):
+            yield next(self.iterator)
+
+
+class _RepeatSampler(object):
+    """
+    Repeating sampler that avoids recreating the iterator every epoch.
+    """
+
+    def __init__(self, sampler):
+        self.sampler = sampler
+
+    def __iter__(self):
+        while True:
+            yield from iter(self.sampler)
 # def load_gift_data(action):
 #     train_data = ImageFolder(conf.train_gift_root, transform.yml=conf.train_transform)
 #     train_dataset = DataLoader(train_data, batch_size=conf.train_gift_batchsize, shuffle=True,
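The two classes above are the heart of the commit. A stock DataLoader with num_workers > 0 builds a fresh iterator, and with it a fresh pool of worker processes, each time an epoch starts iterating it; at num_workers: 32 that startup cost is paid every epoch. _RepeatSampler yields batch indices in an endless loop, so the single iterator created in MultiEpochsDataLoader.__init__ never exhausts, and __iter__ simply slices off one epoch's worth of batches (__len__ delegates to the wrapped batch sampler, i.e. the number of batches per epoch). A minimal usage sketch mirroring the wiring later in this commit, assuming cfg is the parsed YAML config:

from tools.dataset import load_data, MultiEpochsDataLoader

# sketch: cfg is assumed to be the parsed config shown above
train_dataset, class_num = load_data(training=True, cfg=cfg, return_dataset=True)
train_loader = MultiEpochsDataLoader(train_dataset,
                                     batch_size=cfg['data']['train_batch_size'],
                                     shuffle=True,
                                     num_workers=cfg['data']['num_workers'],
                                     pin_memory=cfg['base']['pin_memory'],
                                     drop_last=True)

for epoch in range(100):                 # worker processes start once,
    for images, labels in train_loader:  # then are reused on every epoch
        pass

On PyTorch 1.7 and later, the built-in DataLoader flag persistent_workers=True achieves the same effect of keeping workers alive across epochs.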
@@ -10,7 +10,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.data.distributed import DistributedSampler
 
 from model.loss import FocalLoss
-from tools.dataset import load_data
+from tools.dataset import load_data, MultiEpochsDataLoader
 import matplotlib.pyplot as plt
 from configs import trainer_tools
 import yaml
@@ -52,7 +52,7 @@ def setup_optimizer_and_scheduler(conf, model, metric):
         scheduler_mapping = tr_tools.get_scheduler(optimizer)
         scheduler = scheduler_mapping[conf['training']['scheduler']]()
         print('Using {} optimizer and {} scheduler'.format(conf['training']['optimizer'],
                                                            conf['training']['scheduler']))
         return optimizer, scheduler
     else:
         raise ValueError('Unsupported optimizer type: {}'.format(conf['training']['optimizer']))
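The scheduler_mapping[...]() call above implies that trainer_tools.get_scheduler returns a dict mapping scheduler names to zero-argument factories. That module is not part of this commit; a hypothetical sketch of the pattern, using standard torch.optim.lr_scheduler classes:

import torch

# hypothetical reconstruction; the real configs.trainer_tools is not shown in this commit
def get_scheduler(optimizer):
    return {
        'step': lambda: torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1),
        'cosine': lambda: torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50),
        'cosine_warm': lambda: torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10),  # cf. cosine_t_0 in the config
    }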
@@ -146,9 +146,21 @@ def initialize_training_components(distributed=False):
     # for non-distributed training, create all components directly
     if not distributed:
         # data loading
-        train_dataloader, class_num = load_data(training=True, cfg=conf)
-        val_dataloader, _ = load_data(training=False, cfg=conf)
+        train_dataloader, class_num = load_data(training=True, cfg=conf, return_dataset=True)
+        val_dataloader, _ = load_data(training=False, cfg=conf, return_dataset=True)
 
+        train_dataloader = MultiEpochsDataLoader(train_dataloader,
+                                                 batch_size=conf['data']['train_batch_size'],
+                                                 shuffle=True,
+                                                 num_workers=conf['data']['num_workers'],
+                                                 pin_memory=conf['base']['pin_memory'],
+                                                 drop_last=True)
+        val_dataloader = MultiEpochsDataLoader(val_dataloader,
+                                               batch_size=conf['data']['val_batch_size'],
+                                               shuffle=False,
+                                               num_workers=conf['data']['num_workers'],
+                                               pin_memory=conf['base']['pin_memory'],
+                                               drop_last=False)
        # initialize the model and metric
         model, metric = initialize_model_and_metric(conf, class_num)
         device = conf['base']['device']
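A simple way to observe the gain from this change is to time successive epochs: with a plain DataLoader every epoch pays the worker startup cost, while with MultiEpochsDataLoader only the first one does. A sketch, assuming train_dataloader from above:

import time

for epoch in range(3):
    t0 = time.time()
    for images, labels in train_dataloader:
        pass  # no training step; measures data loading only
    print(f'epoch {epoch}: {time.time() - t0:.1f}s')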
@@ -248,10 +260,10 @@ def main():
     """Main entry point."""
     # load the configuration
     conf = load_configuration()
 
     # check whether distributed training is enabled
     distributed = conf['base']['distributed']
 
     if distributed:
         # distributed training: launch multiple processes via mp.spawn
         world_size = torch.cuda.device_count()
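For reference, the launch pattern implied by this hunk is torch.multiprocessing.spawn with one process per GPU; a sketch using the names from this file:

import torch.multiprocessing as mp

# spawn one training process per visible GPU; the rank is passed as the first argument
mp.spawn(run_training, args=(world_size, conf), nprocs=world_size)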
@@ -274,56 +286,56 @@ def run_training(rank, world_size, conf):
     os.environ['WORLD_SIZE'] = str(world_size)
     os.environ['MASTER_ADDR'] = 'localhost'
     os.environ['MASTER_PORT'] = '12355'
 
     dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)
     torch.cuda.set_device(rank)
     device = torch.device('cuda', rank)
 
-    # create the data loaders, model, and other components (distributed case)
-    train_dataloader, class_num = load_data(training=True, cfg=conf)
-    val_dataloader, _ = load_data(training=False, cfg=conf)
+    # fetch the datasets instead of DataLoaders
+    train_dataset, class_num = load_data(training=True, cfg=conf, return_dataset=True)
+    val_dataset, _ = load_data(training=False, cfg=conf, return_dataset=True)
 
     # initialize the model and metric
     model, metric = initialize_model_and_metric(conf, class_num)
     model = model.to(device)
     metric = metric.to(device)
 
     # wrap with DistributedDataParallel
     model = DDP(model, device_ids=[rank], output_device=rank)
     metric = DDP(metric, device_ids=[rank], output_device=rank)
 
     # set up the loss function, optimizer, and scheduler
     criterion = setup_loss_function(conf)
     optimizer, scheduler = setup_optimizer_and_scheduler(conf, model, metric)
 
     # checkpoint directory
     checkpoints = conf['training']['checkpoints']
     os.makedirs(checkpoints, exist_ok=True)
 
     # GradScaler for mixed precision
     scaler = torch.cuda.amp.GradScaler()
 
-    # create distributed data loaders
-    train_sampler = DistributedSampler(train_dataloader.dataset, shuffle=True)
-    val_sampler = DistributedSampler(val_dataloader.dataset, shuffle=False)
+    # create distributed samplers
+    train_sampler = DistributedSampler(train_dataset, shuffle=True)
+    val_sampler = DistributedSampler(val_dataset, shuffle=False)
 
-    # recreate data loaders suited to distributed training
-    train_dataloader = torch.utils.data.DataLoader(
-        train_dataloader.dataset,
-        batch_size=train_dataloader.batch_size,
+    # create distributed data loaders with MultiEpochsDataLoader
+    train_dataloader = MultiEpochsDataLoader(
+        train_dataset,
+        batch_size=conf['data']['train_batch_size'],
         sampler=train_sampler,
-        num_workers=train_dataloader.num_workers,
-        pin_memory=train_dataloader.pin_memory,
-        drop_last=train_dataloader.drop_last
+        num_workers=conf['data']['num_workers'],
+        pin_memory=conf['base']['pin_memory'],
+        drop_last=True
     )
 
-    val_dataloader = torch.utils.data.DataLoader(
-        val_dataloader.dataset,
-        batch_size=val_dataloader.batch_size,
+    val_dataloader = MultiEpochsDataLoader(
+        val_dataset,
+        batch_size=conf['data']['val_batch_size'],
         sampler=val_sampler,
-        num_workers=val_dataloader.num_workers,
-        pin_memory=val_dataloader.pin_memory,
-        drop_last=val_dataloader.drop_last
+        num_workers=conf['data']['num_workers'],
+        pin_memory=conf['base']['pin_memory'],
+        drop_last=False
     )
 
     # build the components dict
@@ -341,7 +353,7 @@ def run_training(rank, world_size, conf):
         'device': device,
         'distributed': True  # because this runs inside mp.spawn
     }
 
     # run the training loop
     run_training_loop(components)
 
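One interaction worth flagging on the distributed path: DistributedSampler only reshuffles when the training loop calls set_epoch before each epoch, and because MultiEpochsDataLoader keeps a single live iterator that prefetches ahead, a few batches near an epoch boundary may be drawn before the new epoch number takes effect. Assuming run_training_loop iterates epochs with access to train_sampler, the usual pattern is:

# sketch: the 'epochs' config key and loop placement are assumptions
for epoch in range(conf['training']['epochs']):
    train_sampler.set_epoch(epoch)  # vary the shuffle across epochs
    for images, labels in train_dataloader:
        ...  # forward/backward as usual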