122 lines
3.9 KiB
Python
Executable File
122 lines
3.9 KiB
Python
Executable File
import os
|
|
import pdb
|
|
|
|
import torch
|
|
import torch.utils.data as data
|
|
|
|
from cirtorch.datasets.datahelpers import default_loader, imresize
|
|
|
|
|
|
class ImagesFromList(data.Dataset):
    """A generic data loader that loads images from a list of file paths.

    (Based on ImageFolder from pytorch)

    Args:
        root (string): Root directory path; joined in front of every entry of
            ``images``.
        images (list): Relative image paths as strings.
        imsize (int, optional): Maximum size of the longer image side after
            resizing. Default: None (no resizing).
        bbxs (list, optional): List of (x1, y1, x2, y2) tuples, indexed like
            ``images``, used to crop the query images. Default: None (no crop).
        transform (callable, optional): A function/transform that takes in a PIL
            image and returns a transformed version. E.g, ``transforms.RandomCrop``
        loader (callable, optional): A function to load an image given its path.
            The returned object must support PIL-style ``size`` and ``crop``,
            since ``__getitem__`` uses both.

    Attributes:
        images_fn (list): List of full image filenames (``root`` joined with
            each entry of ``images``).

    Raises:
        RuntimeError: If ``images`` is empty.
    """

    def __init__(self, root, images, imsize=None, bbxs=None, transform=None, loader=default_loader):
        images_fn = [os.path.join(root, name) for name in images]

        if not images_fn:
            raise RuntimeError("Dataset contains 0 images!")

        self.root = root
        self.images = images
        self.imsize = imsize
        self.images_fn = images_fn
        self.bbxs = bbxs
        self.transform = transform
        self.loader = loader

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: ``(image, path)`` where ``image`` is the loaded image after
            the optional crop, resize and ``transform`` steps, and ``path`` is
            the full filename it was loaded from.
        """
        path = self.images_fn[index]
        img = self.loader(path)
        # Longer side of the *uncropped* image; needed below to scale crops.
        imfullsize = max(img.size)

        if self.bbxs is not None:
            img = img.crop(self.bbxs[index])

        if self.imsize is not None:
            if self.bbxs is not None:
                # The crop is scaled by imsize / imfullsize, the factor applied
                # to the full image (assuming imresize maps the longer side to
                # the given size, as ``imsize`` is documented), so crops keep
                # the same scale as un-cropped images.
                img = imresize(img, self.imsize * max(img.size) / imfullsize)
            else:
                img = imresize(img, self.imsize)

        if self.transform is not None:
            img = self.transform(img)

        return img, path

    def __len__(self):
        return len(self.images_fn)

    def __repr__(self):
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += ' Number of images: {}\n'.format(self.__len__())
        fmt_str += ' Root Location: {}\n'.format(self.root)
        tmp = ' Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str
|
|
|
|
class ImagesFromDataList(data.Dataset):
    """Dataset over images that are already held in memory as pytorch tensors.

    (Based on ImageFolder from pytorch)

    Args:
        images (list): Images as tensors.
        transform (callable, optional): A function/transform that takes an image
            tensor and returns a transformed version, e.g. ``normalize`` with
            mean and std.

    Raises:
        RuntimeError: If ``images`` is empty.
    """

    def __init__(self, images, transform=None):
        count = len(images)
        if not count:
            raise RuntimeError("Dataset contains 0 images!")

        self.images, self.transform = images, transform

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            image (Tensor): The (optionally transformed) image, with a new
            leading dimension of size 1 whenever it has at least one dimension.
        """
        sample = self.images[index]
        sample = sample if self.transform is None else self.transform(sample)

        # NOTE(review): len(size()) is non-zero for every non-scalar tensor, so
        # a leading dim is always added for real images -- presumably a batch
        # dimension; confirm whether a stricter rank check was intended.
        return sample.unsqueeze(0) if len(sample.size()) > 0 else sample

    def __len__(self):
        return len(self.images)

    def __repr__(self):
        label = ' Transforms (if any): '
        transform_repr = repr(self.transform).replace('\n', '\n' + ' ' * len(label))
        return ''.join([
            'Dataset ' + type(self).__name__ + '\n',
            ' Number of images: {}\n'.format(len(self)),
            '{0}{1}\n'.format(label, transform_repr),
        ])
|