修改Dataloader提升训练效率
This commit is contained in:
@ -5,12 +5,14 @@ import torchvision.transforms as T
|
||||
# from config import config as conf
|
||||
import torch
|
||||
|
||||
|
||||
def pad_to_square(img):
    """Pad a PIL image with a black border so the result is exactly square.

    Args:
        img: PIL.Image — input image; ``img.size`` is ``(width, height)``.

    Returns:
        A new image of size ``(max(w, h), max(w, h))`` with the original
        content centered and constant black (0) padding around it.
    """
    w, h = img.size
    max_wh = max(w, h)
    # Split the deficit between the two sides, giving the odd remainder to
    # the right/bottom edge. The previous version floor-divided BOTH sides,
    # so an odd deficit lost one pixel and the output was not square.
    left = (max_wh - w) // 2
    top = (max_wh - h) // 2
    padding = [left, top, max_wh - w - left, max_wh - h - top]  # (left, top, right, bottom)
    return F.pad(img, padding, fill=0, padding_mode='constant')
|
||||
|
||||
|
||||
def get_transform(cfg):
|
||||
train_transform = T.Compose([
|
||||
T.Lambda(pad_to_square), # 补边
|
||||
@ -32,7 +34,8 @@ def get_transform(cfg):
|
||||
])
|
||||
return train_transform, test_transform
|
||||
|
||||
def load_data(training=True, cfg=None):
|
||||
|
||||
def load_data(training=True, cfg=None, return_dataset=False):
|
||||
train_transform, test_transform = get_transform(cfg)
|
||||
if training:
|
||||
dataroot = cfg['data']['data_train_dir']
|
||||
@ -47,14 +50,49 @@ def load_data(training=True, cfg=None):
|
||||
|
||||
data = ImageFolder(dataroot, transform=transform)
|
||||
class_num = len(data.classes)
|
||||
loader = DataLoader(data,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
pin_memory=cfg['base']['pin_memory'],
|
||||
num_workers=cfg['data']['num_workers'],
|
||||
drop_last=True)
|
||||
return loader, class_num
|
||||
if return_dataset:
|
||||
return data, class_num
|
||||
else:
|
||||
loader = DataLoader(data,
|
||||
batch_size=batch_size,
|
||||
shuffle=True if training else False,
|
||||
pin_memory=cfg['base']['pin_memory'],
|
||||
num_workers=cfg['data']['num_workers'],
|
||||
drop_last=True)
|
||||
return loader, class_num
|
||||
|
||||
class MultiEpochsDataLoader(torch.utils.data.DataLoader):
    """
    DataLoader variant that reuses its worker processes across epochs.

    A stock DataLoader builds a fresh iterator (and respawns workers) every
    epoch; here the batch sampler is wrapped in _RepeatSampler so a single
    persistent iterator, created once in __init__, can serve every epoch and
    the worker startup cost is paid only once.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # DataLoader forbids reassigning batch_sampler after construction;
        # temporarily clearing the name-mangled __initialized flag bypasses
        # that guard just long enough to swap in the repeating sampler.
        self._DataLoader__initialized = False
        self.batch_sampler = _RepeatSampler(self.batch_sampler)
        self._DataLoader__initialized = True
        # Single long-lived iterator; __iter__ below draws from it so the
        # workers behind it are never torn down between epochs.
        self.iterator = super().__iter__()

    def __len__(self):
        # Length of ONE epoch: the wrapped (original) batch sampler's length,
        # not the infinite repeating wrapper.
        return len(self.batch_sampler.sampler)

    def __iter__(self):
        # Yield exactly one epoch's worth of batches from the shared iterator.
        for i in range(len(self)):
            yield next(self.iterator)
|
||||
|
||||
|
||||
class _RepeatSampler(object):
    """Wrapper that re-iterates the wrapped sampler forever.

    Lets a single DataLoader iterator serve every epoch instead of being
    rebuilt (and its workers respawned) each time.
    """

    def __init__(self, sampler):
        # The underlying (finite) sampler to cycle through endlessly.
        self.sampler = sampler

    def __iter__(self):
        # Restart the wrapped sampler each time it is exhausted.
        while True:
            for item in self.sampler:
                yield item
|
||||
# def load_gift_data(action):
|
||||
# train_data = ImageFolder(conf.train_gift_root, transform=conf.train_transform)
|
||||
# train_dataset = DataLoader(train_data, batch_size=conf.train_gift_batchsize, shuffle=True,
|
||||
|
Reference in New Issue
Block a user