69 lines
3.2 KiB
Python
69 lines
3.2 KiB
Python
from torch.utils.data import DataLoader
|
|
from torchvision.datasets import ImageFolder
|
|
import torchvision.transforms.functional as F
|
|
import torchvision.transforms as T
|
|
# from config import config as conf
|
|
import torch
|
|
|
|
def pad_to_square(img):
    """Pad an image with zero-valued borders so that it becomes square.

    The shorter side is padded on both ends so the original content stays
    centred. When the size difference is odd, the extra pixel goes to the
    right/bottom edge, so the result is always exactly ``max(w, h)`` on each
    side.

    Args:
        img: Expected to be a PIL image -- ``img.size`` is unpacked as
            ``(width, height)``.

    Returns:
        The padded image, as returned by
        ``torchvision.transforms.functional.pad``.
    """
    w, h = img.size
    max_wh = max(w, h)
    # Split the deficit on each axis and give the remainder (0 or 1) to the far
    # edge. Using ``diff // 2`` for both sides would drop a pixel when ``diff``
    # is odd and leave the output one pixel short of square.
    pad_left = (max_wh - w) // 2
    pad_top = (max_wh - h) // 2
    pad_right = max_wh - w - pad_left
    pad_bottom = max_wh - h - pad_top
    padding = [pad_left, pad_top, pad_right, pad_bottom]  # (left, top, right, bottom)
    return F.pad(img, padding, fill=0, padding_mode='constant')
|
|
|
|
def get_transform(cfg):
    """Build the training and evaluation preprocessing pipelines.

    Every setting is read from ``cfg['transform']``:
    ``img_size`` (output side length, used for both height and width),
    ``RandomHorizontalFlip`` (flip probability), ``RandomRotation`` (the
    ``degrees`` argument of ``T.RandomRotation``), ``ColorJitter`` (brightness
    jitter), and ``img_mean`` / ``img_std`` (normalisation statistics, each
    wrapped in a one-element list before being passed to ``T.Normalize``).

    Args:
        cfg: Nested configuration mapping containing a ``'transform'`` section.

    Returns:
        ``(train_transform, test_transform)`` -- two ``T.Compose`` pipelines.
    """
    tcfg = cfg['transform']
    side = tcfg['img_size']

    def _resize():
        # Square resize applied after conversion to a tensor.
        return T.Resize((side, side), antialias=True)

    def _float_and_normalize():
        # Shared tail of both pipelines: float32 dtype, then normalisation.
        return [
            T.ConvertImageDtype(torch.float32),
            T.Normalize(mean=[tcfg['img_mean']], std=[tcfg['img_std']]),
        ]

    # Training: pad to square, tensorise and resize, then random augmentation.
    train_steps = [T.Lambda(pad_to_square), T.ToTensor(), _resize()]
    train_steps += [
        T.RandomHorizontalFlip(p=tcfg['RandomHorizontalFlip']),
        T.RandomRotation(tcfg['RandomRotation']),
        T.ColorJitter(brightness=tcfg['ColorJitter']),
    ]
    train_steps += _float_and_normalize()

    # Evaluation: deterministic, no augmentation.
    # NOTE(review): unlike training, evaluation does not apply pad_to_square,
    # so non-square eval images are stretched rather than padded -- confirm
    # this asymmetry is intended.
    test_steps = [T.ToTensor(), _resize(), *_float_and_normalize()]

    return T.Compose(train_steps), T.Compose(test_steps)
|
|
|
|
def load_data(training=True, cfg=None):
    """Create a DataLoader over an ImageFolder dataset.

    Args:
        training: When true, load ``cfg['data']['data_train_dir']`` with the
            augmenting training transform and ``train_batch_size``, shuffling
            and dropping the last incomplete batch. When false, load
            ``cfg['data']['data_val_dir']`` with the evaluation transform and
            ``val_batch_size``, in a fixed order and keeping every sample.
        cfg: Nested configuration mapping with ``'data'``, ``'base'`` and
            ``'transform'`` sections. Required despite the ``None`` default.

    Returns:
        ``(loader, class_num)`` where ``class_num`` is ``len(data.classes)``
        of the ImageFolder dataset.

    Raises:
        TypeError: If ``cfg`` is not provided.
    """
    if cfg is None:
        raise TypeError("load_data() requires a cfg mapping")

    train_transform, test_transform = get_transform(cfg)
    data_cfg = cfg['data']

    if training:
        dataroot = data_cfg['data_train_dir']
        transform = train_transform
        batch_size = data_cfg['train_batch_size']
    else:
        dataroot = data_cfg['data_val_dir']
        transform = test_transform
        batch_size = data_cfg['val_batch_size']

    data = ImageFolder(dataroot, transform=transform)
    class_num = len(data.classes)

    # Shuffle and drop the trailing partial batch only while training. For
    # validation, every sample must be seen so metrics cover the whole split,
    # and a fixed order keeps evaluation reproducible.
    is_train = bool(training)
    loader = DataLoader(data,
                        batch_size=batch_size,
                        shuffle=is_train,
                        pin_memory=cfg['base']['pin_memory'],
                        num_workers=data_cfg['num_workers'],
                        drop_last=is_train)

    return loader, class_num
|
|
|
|
# def load_gift_data(action):
|
|
# train_data = ImageFolder(conf.train_gift_root, transform=conf.train_transform)
|
|
# train_dataset = DataLoader(train_data, batch_size=conf.train_gift_batchsize, shuffle=True,
|
|
# pin_memory=conf.pin_memory, num_workers=conf.num_workers, drop_last=True)
|
|
# val_data = ImageFolder(conf.test_gift_root, transform=conf.test_transform)
|
|
# val_dataset = DataLoader(val_data, batch_size=conf.val_gift_batchsize, shuffle=True,
|
|
# pin_memory=conf.pin_memory, num_workers=conf.num_workers, drop_last=True)
|
|
# test_data = ImageFolder(conf.test_gift_root, transform=conf.test_transform)
|
|
# test_dataset = DataLoader(test_data, batch_size=conf.test_gift_batchsize, shuffle=True,
|
|
# pin_memory=conf.pin_memory, num_workers=conf.num_workers, drop_last=True)
|
|
# return train_dataset, val_dataset, test_dataset
|