Files
ieemoo-ai-searchv2/cirtorch/datasets/genericdataset.py
2022-11-22 15:32:06 +08:00

122 lines
3.9 KiB
Python
Executable File

import os
import pdb
import torch
import torch.utils.data as data
from cirtorch.datasets.datahelpers import default_loader, imresize
class ImagesFromList(data.Dataset):
    """A generic data loader that loads images from a list of file paths.

    (Based on ImageFolder from pytorch)

    Args:
        root (string): Root directory path; every entry of ``images`` is
            joined onto it with ``os.path.join``.
        images (list): Relative image paths as strings.
        imsize (int, Default: None): Defines the maximum size of longer image
            side. It is passed to ``imresize``. When ``bbxs`` is also given, the
            size passed for the crop is ``imsize`` scaled by
            ``max(crop.size) / max(full_image.size)``.
        bbxs (list): List of (x1,y1,x2,y2) tuples to crop the query images.
            Indexed in parallel with ``images``.
        transform (callable, optional): A function/transform that takes in an
            PIL image and returns a transformed version.
            E.g, ``transforms.RandomCrop``
        loader (callable, optional): A function to load an image given its path.

    Attributes:
        images_fn (list): List of full image filename

    Raises:
        RuntimeError: If ``images`` is empty.
    """

    def __init__(self, root, images, imsize=None, bbxs=None, transform=None, loader=default_loader):
        images_fn = [os.path.join(root, name) for name in images]
        if not images_fn:
            raise RuntimeError("Dataset contains 0 images!")
        self.root = root
        self.images = images
        self.imsize = imsize
        self.images_fn = images_fn
        self.bbxs = bbxs
        self.transform = transform
        self.loader = loader

    def __getitem__(self, index):
        """Load, optionally crop / resize / transform, and return one image.

        Args:
            index (int): Index into ``images_fn`` (and ``bbxs``, if given).

        Returns:
            tuple: ``(image, path)``. ``image`` is the loaded image after the
            optional crop, resize and transform. ``path`` is the full filename
            it was loaded from.
        """
        path = self.images_fn[index]
        img = self.loader(path)
        # Longer side of the *uncropped* image. A crop is resized by the same
        # ratio the full image would be, so query crops keep a consistent scale.
        imfullsize = max(img.size)

        if self.bbxs is not None:
            img = img.crop(self.bbxs[index])

        if self.imsize is not None:
            if self.bbxs is not None:
                img = imresize(img, self.imsize * max(img.size) / imfullsize)
            else:
                img = imresize(img, self.imsize)

        if self.transform is not None:
            img = self.transform(img)

        return img, path

    def __len__(self):
        """Return the number of image paths in the dataset."""
        return len(self.images_fn)

    def __repr__(self):
        """Summarise the dataset size, root directory and transform."""
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += ' Number of images: {}\n'.format(self.__len__())
        fmt_str += ' Root Location: {}\n'.format(self.root)
        tmp = ' Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str
class ImagesFromDataList(data.Dataset):
    """Dataset over images already held in memory as PyTorch tensors.

    (Based on ImageFolder from pytorch.) Instead of reading files, it indexes
    into the caller-supplied ``images`` sequence.

    Args:
        images (list): Images as tensors.
        transform (callable, optional): Applied to each tensor image before it
            is returned, e.g. ``normalize`` with mean and std.

    Raises:
        RuntimeError: If ``images`` is empty.
    """

    def __init__(self, images, transform=None):
        n_images = len(images)
        if n_images == 0:
            raise RuntimeError("Dataset contains 0 images!")
        self.images = images
        self.transform = transform

    def __getitem__(self, index):
        """Return the (optionally transformed) image at ``index``.

        Args:
            index (int): Index
        Returns:
            image (Tensor): The image with a new leading axis of size 1,
            unless it is 0-dimensional.
        """
        sample = self.images[index]
        transform = self.transform
        if transform is not None:
            sample = transform(sample)
        # Every tensor with at least one dimension gets unsqueezed at dim 0;
        # only a 0-d result is returned as is.
        # NOTE(review): presumably meant to add a batch dim to a (C, H, W)
        # image -- confirm; the check does not restrict to 3-D inputs.
        has_dims = len(sample.size()) > 0
        return sample.unsqueeze(0) if has_dims else sample

    def __len__(self):
        """Return how many images the dataset holds."""
        return len(self.images)

    def __repr__(self):
        """Summarise the dataset size and transform."""
        prefix = ' Transforms (if any): '
        transform_text = repr(self.transform).replace('\n', '\n' + ' ' * len(prefix))
        return ''.join([
            'Dataset {}\n'.format(type(self).__name__),
            ' Number of images: {}\n'.format(len(self)),
            '{0}{1}\n'.format(prefix, transform_text),
        ])