ieemoo-ai-contrast/tools/dataset.py
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
import torchvision.transforms.functional as F
import torchvision.transforms as T
# from config import config as conf
import torch


def pad_to_square(img):
    """Pad a PIL image with black borders so it becomes square."""
    w, h = img.size
    max_wh = max(w, h)
    # (left, top, right, bottom); the right/bottom sides absorb the odd pixel
    # so the result is exactly max_wh x max_wh.
    padding = [(max_wh - w) // 2, (max_wh - h) // 2,
               max_wh - w - (max_wh - w) // 2, max_wh - h - (max_wh - h) // 2]
    return F.pad(img, padding, fill=0, padding_mode='constant')

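# Quick sanity-check sketch (illustrative only, not part of the training code):
# a 300x200 RGB image gains 50 rows of black padding on the top and 50 on the
# bottom, giving a 300x300 output.
#
#   from PIL import Image
#   demo = Image.new('RGB', (300, 200))
#   assert pad_to_square(demo).size == (300, 300)
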
def get_transform(cfg):
    train_transform = T.Compose([
        T.Lambda(pad_to_square),  # pad to a square before resizing
        T.ToTensor(),
        T.Resize((cfg['transform']['img_size'], cfg['transform']['img_size']), antialias=True),
        # T.RandomCrop(img_size * 4 // 5),
        T.RandomHorizontalFlip(p=cfg['transform']['RandomHorizontalFlip']),
        T.RandomRotation(cfg['transform']['RandomRotation']),
        T.ColorJitter(brightness=cfg['transform']['ColorJitter']),
        T.ConvertImageDtype(torch.float32),
        T.Normalize(mean=[cfg['transform']['img_mean']], std=[cfg['transform']['img_std']]),
    ])
    test_transform = T.Compose([
        T.Lambda(pad_to_square),  # pad to a square before resizing
        T.ToTensor(),
        T.Resize((cfg['transform']['img_size'], cfg['transform']['img_size']), antialias=True),
        T.ConvertImageDtype(torch.float32),
        T.Normalize(mean=[cfg['transform']['img_mean']], std=[cfg['transform']['img_std']]),
    ])
    return train_transform, test_transform

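# The cfg['transform'] section is expected to carry the keys read above. A minimal
# sketch (the values here are illustrative placeholders, not the project's defaults):
#
#   cfg = {'transform': {'img_size': 224,
#                        'RandomHorizontalFlip': 0.5,
#                        'RandomRotation': 10,
#                        'ColorJitter': 0.3,
#                        'img_mean': 0.5,
#                        'img_std': 0.5}}
#   train_tf, test_tf = get_transform(cfg)
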
def load_data(training=True, cfg=None, return_dataset=False):
    """Build the train/val ImageFolder dataset and, unless return_dataset is set,
    wrap it in a DataLoader."""
    train_transform, test_transform = get_transform(cfg)
    if training:
        dataroot = cfg['data']['data_train_dir']
        transform = train_transform
        # transform = conf.train_transform
        batch_size = cfg['data']['train_batch_size']
    else:
        dataroot = cfg['data']['data_val_dir']
        # transform = conf.test_transform
        transform = test_transform
        batch_size = cfg['data']['val_batch_size']

    data = ImageFolder(dataroot, transform=transform)
    class_num = len(data.classes)
    if return_dataset:
        return data, class_num
    else:
        loader = DataLoader(data,
                            batch_size=batch_size,
                            shuffle=training,
                            pin_memory=cfg['base']['pin_memory'],
                            num_workers=cfg['data']['num_workers'],
                            drop_last=True)
        return loader, class_num

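# Usage sketch (assumes a YAML config carrying the 'data', 'base' and 'transform'
# sections referenced above; the file name below is hypothetical):
#
#   import yaml
#   with open('configs/compare.yml') as f:
#       cfg = yaml.safe_load(f)
#   train_loader, class_num = load_data(training=True, cfg=cfg)
#   val_dataset, _ = load_data(training=False, cfg=cfg, return_dataset=True)
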
class MultiEpochsDataLoader(torch.utils.data.DataLoader):
    """DataLoader that reuses its worker processes across epochs instead of
    restarting them at the start of every epoch, which speeds up data loading."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._DataLoader__initialized = False
        self.batch_sampler = _RepeatSampler(self.batch_sampler)
        self._DataLoader__initialized = True
        self.iterator = super().__iter__()

    def __len__(self):
        return len(self.batch_sampler.sampler)

    def __iter__(self):
        for i in range(len(self)):
            yield next(self.iterator)


class _RepeatSampler(object):
    """Sampler that repeats forever, so the iterator is never recreated between epochs."""

    def __init__(self, sampler):
        self.sampler = sampler

    def __iter__(self):
        while True:
            yield from iter(self.sampler)

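# MultiEpochsDataLoader is a drop-in replacement for torch.utils.data.DataLoader;
# a minimal sketch of swapping it in for the training loader built above (cfg as
# in the load_data usage sketch):
#
#   data, class_num = load_data(training=True, cfg=cfg, return_dataset=True)
#   loader = MultiEpochsDataLoader(data,
#                                  batch_size=cfg['data']['train_batch_size'],
#                                  shuffle=True,
#                                  num_workers=cfg['data']['num_workers'],
#                                  pin_memory=cfg['base']['pin_memory'],
#                                  drop_last=True)
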
# def load_gift_data(action):
#     train_data = ImageFolder(conf.train_gift_root, transform=conf.train_transform)
#     train_dataset = DataLoader(train_data, batch_size=conf.train_gift_batchsize, shuffle=True,
#                                pin_memory=conf.pin_memory, num_workers=conf.num_workers, drop_last=True)
#     val_data = ImageFolder(conf.test_gift_root, transform=conf.test_transform)
#     val_dataset = DataLoader(val_data, batch_size=conf.val_gift_batchsize, shuffle=True,
#                              pin_memory=conf.pin_memory, num_workers=conf.num_workers, drop_last=True)
#     test_data = ImageFolder(conf.test_gift_root, transform=conf.test_transform)
#     test_dataset = DataLoader(test_data, batch_size=conf.test_gift_batchsize, shuffle=True,
#                               pin_memory=conf.pin_memory, num_workers=conf.num_workers, drop_last=True)
#     return train_dataset, val_dataset, test_dataset