21 lines
687 B
Python
21 lines
687 B
Python
from torch.utils.data import DataLoader
|
|
from torchvision.datasets import ImageFolder
|
|
|
|
from config import config as conf
|
|
|
|
|
|
def load_data(conf, training=True):
    """Build a DataLoader over an ImageFolder dataset for training or testing.

    Args:
        conf: Configuration object. This function reads ``train_root`` /
            ``test_root``, ``train_transform`` / ``test_transform``,
            ``train_batch_size`` / ``test_batch_size``, ``pin_memory`` and
            ``num_workers`` from it.
        training: If True, use the training settings; otherwise use the
            test settings.

    Returns:
        A ``(loader, class_num)`` tuple, where ``loader`` is a shuffled
        ``DataLoader`` and ``class_num`` is ``len(data.classes)`` of the
        ``ImageFolder`` built from the selected data root.
    """
    if training:
        dataroot = conf.train_root
        transform = conf.train_transform
        batch_size = conf.train_batch_size
    else:
        dataroot = conf.test_root
        transform = conf.test_transform
        batch_size = conf.test_batch_size

    data = ImageFolder(dataroot, transform=transform)
    class_num = len(data.classes)

    # Only drop the incomplete final batch while training. For evaluation,
    # drop_last=True would silently skip up to batch_size - 1 samples, and
    # would yield no batches at all if the test set is smaller than
    # batch_size.
    loader = DataLoader(
        data,
        batch_size=batch_size,
        shuffle=True,
        pin_memory=conf.pin_memory,
        num_workers=conf.num_workers,
        drop_last=training,
    )

    return loader, class_num