import torch
import torchvision.transforms as T
import torchvision.transforms.functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder

# from config import config as conf

||
def pad_to_square(img):
    """Pad a PIL image with black borders so it becomes exactly square.

    The image is centered inside a max(w, h) x max(w, h) canvas.  Left/top
    padding is the floored half-difference; right/bottom absorb any odd
    remainder.  (The previous symmetric `// 2` padding on both sides came
    up one pixel short whenever the width/height difference was odd, so
    the result was not actually square.)

    Args:
        img: PIL image; only `.size` (w, h) is read.

    Returns:
        The padded image from `torchvision.transforms.functional.pad`.
    """
    w, h = img.size
    max_wh = max(w, h)
    left = (max_wh - w) // 2
    top = (max_wh - h) // 2
    # (left, top, right, bottom) — right/bottom take the odd remainder.
    padding = [left, top, max_wh - w - left, max_wh - h - top]
    return F.pad(img, padding, fill=0, padding_mode='constant')
|
||
|
||
|
||
def get_transform(cfg):
    """Build the (train, test) torchvision pipelines from *cfg*.

    Both pipelines pad to square, convert to tensor, resize to the
    configured square size, cast to float32 and normalize; the train
    pipeline additionally applies flip / rotation / color-jitter
    augmentation between the resize and the dtype conversion.

    Args:
        cfg: config mapping; reads the ``cfg['transform']`` section.

    Returns:
        Tuple ``(train_transform, test_transform)``.
    """
    tc = cfg['transform']
    side = tc['img_size']

    # Steps shared by both pipelines, before augmentation.
    head = [
        T.Lambda(pad_to_square),
        T.ToTensor(),
        T.Resize((side, side), antialias=True),
    ]
    # Train-only augmentation steps.
    augment = [
        T.RandomHorizontalFlip(p=tc['RandomHorizontalFlip']),
        T.RandomRotation(tc['RandomRotation']),
        T.ColorJitter(brightness=tc['ColorJitter']),
    ]
    # Steps shared by both pipelines, after augmentation.
    tail = [
        T.ConvertImageDtype(torch.float32),
        T.Normalize(mean=[tc['img_mean']], std=[tc['img_std']]),
    ]

    train_transform = T.Compose(head + augment + tail)
    test_transform = T.Compose(head + tail)
    return train_transform, test_transform
|
||
|
||
|
||
def load_data(training=True, cfg=None, return_dataset=False):
    """Build an ImageFolder dataset (and optionally a DataLoader) from *cfg*.

    Args:
        training: selects the train directory/transform/batch size when
            truthy, otherwise the validation ones; also controls shuffling.
        cfg: config mapping; reads ``cfg['data']`` and ``cfg['base']``.
        return_dataset: when True, skip DataLoader construction.

    Returns:
        ``(dataset, class_num)`` if *return_dataset* else
        ``(loader, class_num)``.
    """
    train_tf, test_tf = get_transform(cfg)
    data_cfg = cfg['data']

    if training:
        root = data_cfg['data_train_dir']
        tf, batch_size = train_tf, data_cfg['train_batch_size']
    else:
        root = data_cfg['data_val_dir']
        tf, batch_size = test_tf, data_cfg['val_batch_size']

    dataset = ImageFolder(root, transform=tf)
    class_num = len(dataset.classes)

    if return_dataset:
        return dataset, class_num

    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=bool(training),  # shuffle only during training
        pin_memory=cfg['base']['pin_memory'],
        num_workers=data_cfg['num_workers'],
        drop_last=True,
    )
    return loader, class_num
|
||
|
||
class MultiEpochsDataLoader(torch.utils.data.DataLoader):
    """
    DataLoader that keeps one persistent iterator (and therefore its worker
    processes) alive across epochs, instead of restarting workers each epoch.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # DataLoader forbids reassigning batch_sampler after construction;
        # temporarily flip its name-mangled `__initialized` flag so we can
        # swap in a sampler that repeats forever.
        self._DataLoader__initialized = False
        self.batch_sampler = _RepeatSampler(self.batch_sampler)
        self._DataLoader__initialized = True
        # Single persistent iterator shared by every __iter__ call.
        self.iterator = super().__iter__()

    def __len__(self):
        # One epoch = the length of the original (wrapped) batch sampler.
        return len(self.batch_sampler.sampler)

    def __iter__(self):
        # Draw exactly one epoch's worth of batches from the shared iterator;
        # the underlying _RepeatSampler never exhausts, so workers stay alive.
        for i in range(len(self)):
            yield next(self.iterator)
|
||
|
||
|
||
class _RepeatSampler(object):
|
||
"""
|
||
重复采样器,避免每个epoch重新创建迭代器
|
||
"""
|
||
|
||
def __init__(self, sampler):
|
||
self.sampler = sampler
|
||
|
||
def __iter__(self):
|
||
while True:
|
||
yield from iter(self.sampler)
|
||
# def load_gift_data(action):
#     train_data = ImageFolder(conf.train_gift_root, transform=conf.train_transform)
#     train_dataset = DataLoader(train_data, batch_size=conf.train_gift_batchsize, shuffle=True,
#                                pin_memory=conf.pin_memory, num_workers=conf.num_workers, drop_last=True)
#     val_data = ImageFolder(conf.test_gift_root, transform=conf.test_transform)
#     val_dataset = DataLoader(val_data, batch_size=conf.val_gift_batchsize, shuffle=True,
#                              pin_memory=conf.pin_memory, num_workers=conf.num_workers, drop_last=True)
#     test_data = ImageFolder(conf.test_gift_root, transform=conf.test_transform)
#     test_dataset = DataLoader(test_data, batch_size=conf.test_gift_batchsize, shuffle=True,
#                               pin_memory=conf.pin_memory, num_workers=conf.num_workers, drop_last=True)
#     return train_dataset, val_dataset, test_dataset
|