Files
ieemoo-ai-contrast/tools/dataset.py
2025-08-06 17:03:28 +08:00

69 lines
3.2 KiB
Python

from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
import torchvision.transforms.functional as F
import torchvision.transforms as T
# from config import config as conf
import torch
def pad_to_square(img, fill=0):
    """Pad an image with a constant border so that it becomes exactly square.

    The shorter side is padded roughly evenly on both edges. When the size
    difference is odd, the extra pixel goes to the right/bottom edge so the
    output is ``max(w, h)`` pixels on both sides.

    Args:
        img: Image whose ``.size`` is ``(width, height)``, e.g. a PIL image.
            It is applied before ``T.ToTensor()`` in this module's pipelines.
        fill: Constant pixel value for the padded border (default 0).

    Returns:
        The padded image, as returned by ``torchvision.transforms.functional.pad``.
    """
    w, h = img.size
    side = max(w, h)
    # Split the padding so that left + right == side - w (and top + bottom ==
    # side - h) exactly. Using diff // 2 on both edges loses one pixel when
    # diff is odd, leaving the result one pixel short of square.
    pad_left = (side - w) // 2
    pad_top = (side - h) // 2
    pad_right = side - w - pad_left
    pad_bottom = side - h - pad_top
    # F.pad interprets a 4-element sequence as (left, top, right, bottom).
    return F.pad(img, [pad_left, pad_top, pad_right, pad_bottom], fill=fill, padding_mode='constant')
def _per_channel(value):
    """Return ``value`` as a list suitable for ``T.Normalize`` mean/std.

    A list/tuple is treated as per-channel values and passed through as a list.
    Anything else (e.g. a single float) is wrapped in a one-element list, which
    ``T.Normalize`` broadcasts over every channel.
    """
    if isinstance(value, (list, tuple)):
        return list(value)
    return [value]


def get_transform(cfg):
    """Build the training and evaluation image pipelines from a config mapping.

    Both pipelines:
      1. pad the image to a square;
      2. convert it to a tensor;
      3. resize it to ``img_size x img_size``;
      4. cast it to float32 and normalize it.

    The training pipeline also applies a random horizontal flip, a random
    rotation and brightness jitter between the resize and the normalization.

    Args:
        cfg: Config mapping. Only ``cfg['transform']`` is read, with keys:
            ``img_size``, ``RandomHorizontalFlip`` (flip probability),
            ``RandomRotation`` (passed to ``T.RandomRotation``),
            ``ColorJitter`` (passed as ``brightness`` to ``T.ColorJitter``),
            ``img_mean`` and ``img_std``. ``img_mean``/``img_std`` may be a
            single number (applied to all channels) or a per-channel
            list/tuple.

    Returns:
        ``(train_transform, test_transform)``, two ``T.Compose`` objects.
    """
    tcfg = cfg['transform']
    size = (tcfg['img_size'], tcfg['img_size'])
    mean = _per_channel(tcfg['img_mean'])
    std = _per_channel(tcfg['img_std'])

    def preprocess():
        # Build fresh transform instances for each pipeline.
        return [
            T.Lambda(pad_to_square),  # pad first so Resize keeps the aspect ratio
            T.ToTensor(),
            T.Resize(size, antialias=True),
        ]

    def finalize():
        return [
            # Guarantees a float32 tensor reaches Normalize. For 8-bit images
            # ToTensor already yields float32, so this is then a no-op.
            T.ConvertImageDtype(torch.float32),
            T.Normalize(mean=mean, std=std),
        ]

    augment = [
        T.RandomHorizontalFlip(p=tcfg['RandomHorizontalFlip']),
        T.RandomRotation(tcfg['RandomRotation']),
        T.ColorJitter(brightness=tcfg['ColorJitter']),
    ]
    train_transform = T.Compose(preprocess() + augment + finalize())
    test_transform = T.Compose(preprocess() + finalize())
    return train_transform, test_transform
def load_data(training=True, cfg=None):
    """Create a DataLoader over an ``ImageFolder`` dataset.

    Args:
        training: If truthy, use the training split:
            ``cfg['data']['data_train_dir']``, the augmenting train transform
            and ``train_batch_size``, with shuffling and the last partial
            batch dropped. Otherwise use the validation split:
            ``cfg['data']['data_val_dir']``, the deterministic test transform
            and ``val_batch_size``, in a fixed order and keeping the last
            partial batch so every validation image is seen.
        cfg: Config mapping. Reads ``cfg['data']`` (dirs, batch sizes,
            ``num_workers``), ``cfg['base']['pin_memory']`` and, through
            ``get_transform``, ``cfg['transform']``. It is required; the
            ``None`` default exists only for signature compatibility.

    Returns:
        ``(loader, class_num)``, where ``class_num`` is ``len(data.classes)``
        of the ``ImageFolder`` built over the chosen directory.

    Raises:
        TypeError: if ``cfg`` is None.
    """
    if cfg is None:
        raise TypeError("load_data() requires a config mapping; got cfg=None")
    train_transform, test_transform = get_transform(cfg)
    data_cfg = cfg['data']
    if training:
        dataroot = data_cfg['data_train_dir']
        transform = train_transform
        batch_size = data_cfg['train_batch_size']
    else:
        dataroot = data_cfg['data_val_dir']
        transform = test_transform
        batch_size = data_cfg['val_batch_size']
    data = ImageFolder(dataroot, transform=transform)
    # DataLoader's batch sampler requires a real bool for drop_last.
    is_train = bool(training)
    loader = DataLoader(
        data,
        batch_size=batch_size,
        # Shuffle and drop the ragged final batch only for training.
        # drop_last=True on validation would silently skip up to
        # batch_size - 1 images, or the entire set if it is smaller than
        # one batch.
        shuffle=is_train,
        pin_memory=cfg['base']['pin_memory'],
        num_workers=data_cfg['num_workers'],
        drop_last=is_train,
    )
    return loader, len(data.classes)
# Legacy loader for the "gift" dataset, kept for reference only. It relies on the
# module-level `conf` object, whose `from config import config as conf` import is
# commented out at the top of this file, so it cannot run as-is.
# def load_gift_data(action):
#     train_data = ImageFolder(conf.train_gift_root, transform=conf.train_transform)
#     train_dataset = DataLoader(train_data, batch_size=conf.train_gift_batchsize, shuffle=True,
#                                pin_memory=conf.pin_memory, num_workers=conf.num_workers, drop_last=True)
#     val_data = ImageFolder(conf.test_gift_root, transform=conf.test_transform)
#     val_dataset = DataLoader(val_data, batch_size=conf.val_gift_batchsize, shuffle=True,
#                              pin_memory=conf.pin_memory, num_workers=conf.num_workers, drop_last=True)
#     test_data = ImageFolder(conf.test_gift_root, transform=conf.test_transform)
#     test_dataset = DataLoader(test_data, batch_size=conf.test_gift_batchsize, shuffle=True,
#                               pin_memory=conf.pin_memory, num_workers=conf.num_workers, drop_last=True)
#     return train_dataset, val_dataset, test_dataset