Files
ieemoo-ai-imageassessment/contrast/dataset.py
2024-11-27 15:37:10 +08:00

21 lines
687 B
Python

from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from config import config as conf
def load_data(conf, training=True, *, shuffle=True, drop_last=True):
    """Build a ``DataLoader`` over an ``ImageFolder`` dataset.

    Args:
        conf: Configuration object. When ``training`` is true its
            ``train_root``, ``train_transform`` and ``train_batch_size``
            attributes are read; otherwise ``test_root``,
            ``test_transform`` and ``test_batch_size``. ``pin_memory`` and
            ``num_workers`` are read in both cases. Note that this
            parameter shadows the module-level ``conf`` import inside the
            function.
        training: Selects the training settings when true, the test
            settings when false.
        shuffle: Forwarded to ``DataLoader``. Defaults to ``True``, the
            value that used to be hard-coded for both splits; pass
            ``False`` for a fixed iteration order (e.g. evaluation).
        drop_last: Forwarded to ``DataLoader``. Defaults to ``True``, the
            value that used to be hard-coded for both splits; pass
            ``False`` so a trailing partial batch is not discarded.

    Returns:
        A ``(loader, class_num)`` tuple, where ``class_num`` is the number
        of entries in the dataset's ``classes`` list.
    """
    if training:
        dataroot = conf.train_root
        transform = conf.train_transform
        batch_size = conf.train_batch_size
    else:
        dataroot = conf.test_root
        transform = conf.test_transform
        batch_size = conf.test_batch_size

    data = ImageFolder(dataroot, transform=transform)
    class_num = len(data.classes)

    loader = DataLoader(
        data,
        batch_size=batch_size,
        shuffle=shuffle,
        pin_memory=conf.pin_memory,
        num_workers=conf.num_workers,
        drop_last=drop_last,
    )
    return loader, class_num