from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
import torchvision.transforms.functional as F
import torchvision.transforms as T
# from config import config as conf
import torch


def _square_padding(w, h):
    """Return ``[left, top, right, bottom]`` padding that makes a ``w x h`` image square.

    The padding is split as evenly as possible between opposite edges; when the
    size difference is odd, the extra pixel goes to the right/bottom edge so that
    ``w + left + right == h + top + bottom == max(w, h)``.
    """
    side = max(w, h)
    left = (side - w) // 2
    top = (side - h) // 2
    return [left, top, side - w - left, side - h - top]


def pad_to_square(img):
    """Zero-pad an image on its borders so that it becomes square.

    Padding to a square before resizing keeps the original aspect ratio of the
    content instead of stretching it.

    Args:
        img: Image whose ``img.size`` is ``(width, height)`` — presumably a PIL
            image from ``ImageFolder``'s default loader (TODO confirm if a custom
            loader is ever used).

    Returns:
        The padded image, as returned by ``torchvision.transforms.functional.pad``.
    """
    w, h = img.size
    return F.pad(img, _square_padding(w, h), fill=0, padding_mode='constant')


def _as_list(value):
    """Wrap a scalar in a list; pass a list/tuple through as a list.

    Lets ``img_mean`` / ``img_std`` be configured either as one value shared by
    all channels or as a per-channel sequence.
    """
    if isinstance(value, (list, tuple)):
        return list(value)
    return [value]


def _normalize(tcfg):
    """Build the ``T.Normalize`` step from the ``transform`` config section."""
    return T.Normalize(mean=_as_list(tcfg['img_mean']), std=_as_list(tcfg['img_std']))


def get_transform(cfg):
    """Build the training and evaluation image transforms.

    Both pipelines pad the image to a square, convert it to a tensor, resize it
    to ``img_size x img_size``, cast to float32 and normalize. The training
    pipeline additionally applies a random horizontal flip, a random rotation
    and brightness jitter between the resize and the normalization.

    Args:
        cfg: Config mapping. Reads ``cfg['transform']`` keys ``img_size``,
            ``RandomHorizontalFlip`` (flip probability), ``RandomRotation``
            (passed as ``degrees``), ``ColorJitter`` (passed as ``brightness``),
            ``img_mean`` and ``img_std`` (scalar or per-channel sequence).

    Returns:
        ``(train_transform, test_transform)``, both ``T.Compose`` instances.
    """
    tcfg = cfg['transform']
    size = (tcfg['img_size'], tcfg['img_size'])

    train_transform = T.Compose([
        T.Lambda(pad_to_square),  # pad borders to a square
        T.ToTensor(),
        T.Resize(size, antialias=True),
        # T.RandomCrop(img_size * 4 // 5),
        T.RandomHorizontalFlip(p=tcfg['RandomHorizontalFlip']),
        T.RandomRotation(tcfg['RandomRotation']),
        T.ColorJitter(brightness=tcfg['ColorJitter']),
        # ToTensor emits torch's default float dtype; force float32 explicitly.
        T.ConvertImageDtype(torch.float32),
        _normalize(tcfg),
    ])
    test_transform = T.Compose([
        T.Lambda(pad_to_square),  # pad borders to a square
        T.ToTensor(),
        T.Resize(size, antialias=True),
        T.ConvertImageDtype(torch.float32),
        _normalize(tcfg),
    ])
    return train_transform, test_transform


def load_data(training=True, cfg=None):
    """Build an ``ImageFolder`` data loader for the training or validation split.

    Args:
        training: If true, load ``cfg['data']['data_train_dir']`` with the
            augmenting train transform and ``train_batch_size``; batches are
            shuffled and the final partial batch is dropped. Otherwise load
            ``cfg['data']['data_val_dir']`` with the deterministic test transform
            and ``val_batch_size``, in a fixed order and keeping the final
            partial batch so every validation image is evaluated.
        cfg: Config mapping (required despite the ``None`` default). Reads
            ``cfg['data']`` (dirs, batch sizes, ``num_workers``),
            ``cfg['base']['pin_memory']`` and the keys used by ``get_transform``.

    Returns:
        ``(loader, class_num)`` where ``class_num`` is the number of class
        sub-folders found by ``ImageFolder``.

    Raises:
        TypeError: If ``cfg`` is None.
    """
    if cfg is None:
        raise TypeError("load_data() requires a cfg mapping")

    train_transform, test_transform = get_transform(cfg)
    data_cfg = cfg['data']
    if training:
        dataroot = data_cfg['data_train_dir']
        transform = train_transform
        batch_size = data_cfg['train_batch_size']
    else:
        dataroot = data_cfg['data_val_dir']
        transform = test_transform
        batch_size = data_cfg['val_batch_size']

    data = ImageFolder(dataroot, transform=transform)
    class_num = len(data.classes)
    is_train = bool(training)
    loader = DataLoader(
        data,
        batch_size=batch_size,
        # Shuffling and dropping the last partial batch only make sense for
        # training; for validation they would reorder and silently skip samples.
        shuffle=is_train,
        pin_memory=cfg['base']['pin_memory'],
        num_workers=data_cfg['num_workers'],
        drop_last=is_train,
    )
    return loader, class_num


# NOTE(review): legacy loader retained as a comment. It depends on the `conf`
# module whose import is commented out at the top of this file.
# def load_gift_data(action):
#     train_data = ImageFolder(conf.train_gift_root, transform=conf.train_transform)
#     train_dataset = DataLoader(train_data, batch_size=conf.train_gift_batchsize, shuffle=True,
#                                pin_memory=conf.pin_memory, num_workers=conf.num_workers, drop_last=True)
#     val_data = ImageFolder(conf.test_gift_root, transform=conf.test_transform)
#     val_dataset = DataLoader(val_data, batch_size=conf.val_gift_batchsize, shuffle=True,
#                              pin_memory=conf.pin_memory, num_workers=conf.num_workers, drop_last=True)
#     test_data = ImageFolder(conf.test_gift_root, transform=conf.test_transform)
#     test_dataset = DataLoader(test_data, batch_size=conf.test_gift_batchsize, shuffle=True,
#                               pin_memory=conf.pin_memory, num_workers=conf.num_workers, drop_last=True)
#     return train_dataset, val_dataset, test_dataset