first push
This commit is contained in:
0
cirtorch/datasets/__init__.py
Executable file
0
cirtorch/datasets/__init__.py
Executable file
56
cirtorch/datasets/datahelpers.py
Executable file
56
cirtorch/datasets/datahelpers.py
Executable file
@ -0,0 +1,56 @@
|
||||
import os
|
||||
from PIL import Image
|
||||
|
||||
import torch
|
||||
|
||||
def cid2filename(cid, prefix):
    """
    Build the full path of a training image from its CID name.

    The last six characters of ``cid`` are consumed two at a time, starting
    from the end, to name three nested sub-directories below ``prefix``; the
    complete ``cid`` is the final path component.

    Arguments
    ---------
    cid    : name of the image
    prefix : root directory where images are saved

    Returns
    -------
    filename : full image filename
    """
    outer, middle, inner = cid[-2:], cid[-4:-2], cid[-6:-4]
    return os.path.join(prefix, outer, middle, inner, cid)
|
||||
|
||||
def pil_loader(path):
    """Load the image stored at ``path`` with PIL and return it converted to RGB."""
    # open path as file to avoid ResourceWarning
    # (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, 'rb') as handle:
        return Image.open(handle).convert('RGB')
|
||||
|
||||
def accimage_loader(path):
    """Load ``path`` with accimage, falling back to ``pil_loader`` on IOError."""
    import accimage

    try:
        image = accimage.Image(path)
    except IOError:
        # Potentially a decoding problem, fall back to PIL.Image
        image = pil_loader(path)
    return image
|
||||
|
||||
def default_loader(path):
    """Load ``path`` with the loader matching torchvision's current image backend.

    ``accimage_loader`` is used when the backend is ``'accimage'``; every
    other backend value goes through ``pil_loader``.
    """
    from torchvision import get_image_backend

    use_accimage = get_image_backend() == 'accimage'
    load = accimage_loader if use_accimage else pil_loader
    return load(path)
|
||||
|
||||
def imresize(img, imsize):
    """Resize ``img`` in place with ``Image.thumbnail`` and return it.

    Arguments
    ---------
    img    : PIL image; it is modified by the ``thumbnail`` call
    imsize : bound passed as both width and height of the thumbnail box

    Returns
    -------
    img : the same image object that was passed in
    """
    # Image.ANTIALIAS was removed in Pillow 10. Image.LANCZOS names the same
    # filter; fall back to ANTIALIAS only on releases that predate LANCZOS.
    resample = getattr(Image, 'LANCZOS', None)
    if resample is None:
        resample = Image.ANTIALIAS
    img.thumbnail((imsize, imsize), resample)
    return img
|
||||
|
||||
def flip(x, dim):
    """Return ``x`` with the order of its entries reversed along ``dim``.

    The previous implementation built an index tensor and moved it with
    ``.cuda()``, which targets the default CUDA device rather than the device
    holding ``x``, and it relied on ``.view``, which rejects non-contiguous
    input. ``torch.flip`` performs the same reversal without either
    limitation.

    Arguments
    ---------
    x   : tensor to reverse
    dim : dimension to reverse along; negative values count from the end

    Returns
    -------
    tensor with the same shape as ``x``
    """
    return torch.flip(x, dims=[dim])
|
||||
|
||||
def collate_tuples(batch):
    """Split a batch of pairs into two parallel lists.

    Returns a tuple ``(firsts, seconds)`` where ``firsts`` holds element 0
    and ``seconds`` holds element 1 of every item in ``batch``, in order.
    """
    firsts = [item[0] for item in batch]
    seconds = [item[1] for item in batch]
    return firsts, seconds
|
121
cirtorch/datasets/genericdataset.py
Executable file
121
cirtorch/datasets/genericdataset.py
Executable file
@ -0,0 +1,121 @@
|
||||
import os
|
||||
import pdb
|
||||
|
||||
import torch
|
||||
import torch.utils.data as data
|
||||
|
||||
from cirtorch.datasets.datahelpers import default_loader, imresize
|
||||
|
||||
|
||||
class ImagesFromList(data.Dataset):
    """A generic data loader that loads images from a list
    (Based on ImageFolder from pytorch)

    Args:
        root (string): Root directory path.
        images (list): Relative image paths as strings.
        imsize (int, Default: None): Defines the maximum size of longer image side
        bbxs (list): List of (x1,y1,x2,y2) tuples to crop the query images
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        loader (callable, optional): A function to load an image given its path.

    Attributes:
        images_fn (list): List of full image filename
    """

    def __init__(self, root, images, imsize=None, bbxs=None, transform=None, loader=default_loader):

        images_fn = [os.path.join(root, image) for image in images]

        if len(images_fn) == 0:
            raise RuntimeError("Dataset contains 0 images!")

        self.root = root
        self.images = images
        self.imsize = imsize
        self.images_fn = images_fn
        self.bbxs = bbxs
        self.transform = transform
        self.loader = loader

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: ``(image, path)`` - the loaded (optionally cropped, resized
            and transformed) image and the full filename it was loaded from
        """
        path = self.images_fn[index]
        img = self.loader(path)
        # longer side of the uncropped image, used to scale the crop below
        imfullsize = max(img.size)

        if self.bbxs is not None:
            img = img.crop(self.bbxs[index])

        if self.imsize is not None:
            if self.bbxs is not None:
                # scale the target size by the crop's share of the full image
                img = imresize(img, self.imsize * max(img.size) / imfullsize)
            else:
                img = imresize(img, self.imsize)

        if self.transform is not None:
            img = self.transform(img)

        return img, path

    def __len__(self):
        return len(self.images_fn)

    def __repr__(self):
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += ' Number of images: {}\n'.format(self.__len__())
        fmt_str += ' Root Location: {}\n'.format(self.root)
        tmp = ' Transforms (if any): '
        # continuation lines of a multi-line transform repr are indented to
        # line up behind the label
        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str
|
||||
|
||||
class ImagesFromDataList(data.Dataset):
    """Dataset over images that are already held in memory as tensors
    (Based on ImageFolder from pytorch)

    Args:
        images (list): Images as tensors.
        transform (callable, optional): A function/transform that takes an image
            as a tensor and returns a transformed version. E.g, ``normalize``
            with mean and std
    """

    def __init__(self, images, transform=None):
        if len(images) == 0:
            raise RuntimeError("Dataset contains 0 images!")

        self.transform = transform
        self.images = images

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            image (Tensor): Loaded image, with a leading dimension of size one
            added whenever the (transformed) tensor has at least one dimension
        """
        img = self.images[index]
        if self.transform is not None:
            img = self.transform(img)

        has_dims = len(img.size()) != 0
        return img.unsqueeze(0) if has_dims else img

    def __len__(self):
        return len(self.images)

    def __repr__(self):
        label = ' Transforms (if any): '
        # continuation lines of a multi-line transform repr line up behind the label
        transform_text = self.transform.__repr__().replace('\n', '\n' + ' ' * len(label))
        pieces = [
            'Dataset ' + self.__class__.__name__ + '\n',
            ' Number of images: {}\n'.format(self.__len__()),
            '{0}{1}\n'.format(label, transform_text),
        ]
        return ''.join(pieces)
|
38
cirtorch/datasets/testdataset.py
Executable file
38
cirtorch/datasets/testdataset.py
Executable file
@ -0,0 +1,38 @@
|
||||
import os
|
||||
import pickle
|
||||
|
||||
DATASETS = ['oxford5k', 'paris6k', 'roxford5k', 'rparis6k']
|
||||
|
||||
def configdataset(dataset, dir_main):
    """Load the ground-truth pickle of a test dataset and complete its config.

    ``dataset`` is lower-cased and must be one of ``DATASETS``. The pickle
    ``<dir_main>/<dataset>/gnd_<dataset>.pkl`` is loaded as ``cfg``, which is
    then extended with file extensions, directories, image counts, the
    filename helper functions and the dataset name.

    NOTE(review): ``pickle.load`` executes arbitrary code from the file; only
    point ``dir_main`` at trusted data.

    Raises:
        ValueError: if ``dataset`` is not listed in ``DATASETS``.
    """
    dataset = dataset.lower()
    if dataset not in DATASETS:
        raise ValueError('Unknown dataset: {}!'.format(dataset))

    dir_data = os.path.join(dir_main, dataset)

    # loading imlist, qimlist, and gnd, in cfg as a dict
    gnd_fname = os.path.join(dir_main, dataset, 'gnd_{}.pkl'.format(dataset))
    with open(gnd_fname, 'rb') as f:
        cfg = pickle.load(f)

    # written in this order so newly inserted keys keep the original sequence
    additions = (
        ('gnd_fname', gnd_fname),
        ('ext', '.jpg'),
        ('qext', '.jpg'),
        ('dir_data', dir_data),
        ('dir_images', os.path.join(dir_data, 'jpg')),
        ('n', len(cfg['imlist'])),
        ('nq', len(cfg['qimlist'])),
        ('im_fname', config_imname),
        ('qim_fname', config_qimname),
        ('dataset', dataset),
    )
    for key, value in additions:
        cfg[key] = value

    return cfg
|
||||
|
||||
def config_imname(cfg, i):
    """Return the full filename of the ``i``-th entry of ``cfg['imlist']``."""
    directory = cfg['dir_images']
    basename = cfg['imlist'][i] + cfg['ext']
    return os.path.join(directory, basename)
|
||||
|
||||
def config_qimname(cfg, i):
    """Return the full filename of the ``i``-th entry of ``cfg['qimlist']``."""
    directory = cfg['dir_images']
    basename = cfg['qimlist'][i] + cfg['qext']
    return os.path.join(directory, basename)
|
247
cirtorch/datasets/traindataset.py
Executable file
247
cirtorch/datasets/traindataset.py
Executable file
@ -0,0 +1,247 @@
|
||||
import os
|
||||
import pickle
|
||||
import pdb
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.utils.data as data
|
||||
|
||||
from cirtorch.datasets.datahelpers import default_loader, imresize, cid2filename
|
||||
from cirtorch.datasets.genericdataset import ImagesFromList
|
||||
from cirtorch.utils.general import get_data_root
|
||||
|
||||
class TuplesDataset(data.Dataset):
    """Data loader that loads training and validation tuples of
        Radenovic etal ECCV16: CNN image retrieval learns from BoW

    Args:
        name (string): dataset name; must start with 'retrieval-SfM'
        mode (string): 'train' or 'val' for training and validation parts of dataset
        imsize (int, Default: None): Defines the maximum size of longer image side
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        loader (callable, optional): A function to load an image given its path.
        nnum (int, Default:5): Number of negatives for a query image in a training tuple
        qsize (int, Default:2000): Number of query images, ie number of (q,p,n1,...nN) tuples, to be processed in one epoch
        poolsize (int, Default:20000): Pool size for negative images re-mining
        db_root (string, optional): Directory that contains ``train.pkl``.
            Defaults to ``DEFAULT_DB_ROOT``.
        ims_root (string, optional): Directory that contains the images.
            Defaults to ``<db_root>/train``.

    Attributes:
        images (list): List of full filenames for each image
        clusters (list): List of clusterID per image
        qpool (list): List of all query image indexes
        ppool (list): List of positive image indexes, each corresponding to query at the same position in qpool

        qidxs (list): List of qsize query image indexes to be processed in an epoch
        pidxs (list): List of qsize positive image indexes, each corresponding to query at the same position in qidxs
        nidxs (list): List of qsize tuples of negative images
                        Each nidxs tuple contains nnum images corresponding to query image at the same position in qidxs

        Lists qidxs, pidxs, nidxs are refreshed by calling the ``create_epoch_tuples()`` method,
            ie new q-p pairs are picked and negative images are remined
    """

    # Location that used to be hard-coded inside __init__; kept as the default
    # so existing callers see the same paths.
    DEFAULT_DB_ROOT = '/home/lc/project/Search_By_Image_Upgrade/cirtorch/IamgeRetrieval_dataset'

    def __init__(self, name, mode, imsize=None, nnum=5, qsize=2000, poolsize=20000, transform=None,
                 loader=default_loader, db_root=None, ims_root=None):

        if not (mode == 'train' or mode == 'val'):
            raise RuntimeError("MODE should be either train or val, passed as string")

        if name.startswith('retrieval-SfM'):
            # setting up paths
            if db_root is None:
                db_root = self.DEFAULT_DB_ROOT
            if ims_root is None:
                ims_root = os.path.join(db_root, 'train')

            # loading db
            # NOTE(review): pickle.load executes arbitrary code from the file;
            # only use trusted database files.
            db_fn = os.path.join(db_root, '{}.pkl'.format('train'))
            with open(db_fn, 'rb') as f:
                db = pickle.load(f)[mode]

            # setting fullpath for images
            # NOTE(review): the original code first built the nested
            # cid2filename() layout and then immediately overwrote it with
            # this flat layout; only the flat assignment ever took effect, so
            # only it is kept. Confirm the flat layout is the intended one.
            self.images = [os.path.join(ims_root, cid) for cid in db['cids']]
        else:
            raise RuntimeError("Unknown dataset name!")

        # initializing tuples dataset
        self.name = name
        self.mode = mode
        self.imsize = imsize
        self.clusters = db['cluster']
        self.qpool = db['qidxs']
        self.ppool = db['pidxs']

        # size of training subset for an epoch
        self.nnum = nnum
        self.qsize = min(qsize, len(self.qpool))
        self.poolsize = min(poolsize, len(self.images))
        self.qidxs = None
        self.pidxs = None
        self.nidxs = None

        self.transform = transform
        self.loader = loader

        self.print_freq = 10

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: ``(output, target)`` where ``output`` is the list of images
            (q,p,n1,...,nN) for the tuple at ``index`` of self.qidxs and
            ``target`` is a tensor holding -1 for the query, 1 for the
            positive and 0 for every negative

        Raises:
            RuntimeError: if the epoch tuples have not been created yet
        """
        # __len__ reports qsize even before create_epoch_tuples() has run, so
        # the missing-tuples case has to be detected through qidxs itself.
        if self.qidxs is None or self.__len__() == 0:
            raise RuntimeError("List qidxs is empty. Run ``dataset.create_epoch_tuples(net)`` method to create subset for train/val!")

        output = []
        # query image
        output.append(self.loader(self.images[self.qidxs[index]]))
        # positive image
        output.append(self.loader(self.images[self.pidxs[index]]))
        # negative images
        for nidx in self.nidxs[index]:
            output.append(self.loader(self.images[nidx]))

        if self.imsize is not None:
            output = [imresize(img, self.imsize) for img in output]

        if self.transform is not None:
            output = [self.transform(img).unsqueeze_(0) for img in output]

        target = torch.Tensor([-1, 1] + [0]*len(self.nidxs[index]))

        return output, target

    def __len__(self):
        # number of tuples per epoch, independent of whether they were created yet
        return self.qsize

    def __repr__(self):
        fmt_str = self.__class__.__name__ + '\n'
        fmt_str += ' Name and mode: {} {}\n'.format(self.name, self.mode)
        fmt_str += ' Number of images: {}\n'.format(len(self.images))
        fmt_str += ' Number of training tuples: {}\n'.format(len(self.qpool))
        fmt_str += ' Number of negatives per tuple: {}\n'.format(self.nnum)
        fmt_str += ' Number of tuples processed in an epoch: {}\n'.format(self.qsize)
        fmt_str += ' Pool size for negative remining: {}\n'.format(self.poolsize)
        tmp = ' Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str

    def create_epoch_tuples(self, net):
        """Pick new query/positive pairs and mine negatives for one epoch.

        Refreshes ``self.qidxs``, ``self.pidxs`` and ``self.nidxs``. When
        ``self.nnum`` is non-zero, ``net`` is moved to CUDA, switched to eval
        mode and used to extract descriptors for the queries and for a random
        pool of images; negatives are then taken from the highest-scoring
        pool images.

        Args:
            net: network; it is called on image batches and must provide
                ``meta_repr()`` and ``meta['outputdim']``

        Returns:
            ``0`` when ``self.nnum == 0``, otherwise the average l2-distance
            between queries and their selected negatives as a Python float
        """

        print('>> Creating tuples for an epoch of {}-{}...'.format(self.name, self.mode))
        print(">>>> used network: ")
        print(net.meta_repr())

        ## ------------------------
        ## SELECTING POSITIVE PAIRS
        ## ------------------------

        # draw qsize random queries for tuples
        idxs2qpool = torch.randperm(len(self.qpool))[:self.qsize]
        self.qidxs = [self.qpool[i] for i in idxs2qpool]
        self.pidxs = [self.ppool[i] for i in idxs2qpool]

        ## ------------------------
        ## SELECTING NEGATIVE PAIRS
        ## ------------------------

        # if nnum = 0 create dummy nidxs
        # useful when only positives used for training
        if self.nnum == 0:
            self.nidxs = [[] for _ in range(len(self.qidxs))]
            return 0

        # draw poolsize random images for pool of negatives images
        idxs2images = torch.randperm(len(self.images))[:self.poolsize]

        # prepare network
        net.cuda()
        net.eval()

        # no gradients computed, to reduce memory and increase speed
        with torch.no_grad():

            print('>> Extracting descriptors for query images...')
            # prepare query loader
            loader = torch.utils.data.DataLoader(
                ImagesFromList(root='', images=[self.images[i] for i in self.qidxs], imsize=self.imsize, transform=self.transform),
                batch_size=1, shuffle=False, num_workers=8, pin_memory=True
            )
            # extract query vectors
            qvecs = torch.zeros(net.meta['outputdim'], len(self.qidxs)).cuda()
            for i, batch in enumerate(loader):
                # element 0 of each batch is passed to the network
                qvecs[:, i] = net(batch[0].cuda()).data.squeeze()
                if (i+1) % self.print_freq == 0 or (i+1) == len(self.qidxs):
                    print('\r>>>> {}/{} done...'.format(i+1, len(self.qidxs)), end='')
            print('')

            print('>> Extracting descriptors for negative pool...')
            # prepare negative pool data loader
            loader = torch.utils.data.DataLoader(
                ImagesFromList(root='', images=[self.images[i] for i in idxs2images], imsize=self.imsize, transform=self.transform),
                batch_size=1, shuffle=False, num_workers=8, pin_memory=True
            )
            # extract negative pool vectors
            poolvecs = torch.zeros(net.meta['outputdim'], len(idxs2images)).cuda()
            for i, batch in enumerate(loader):
                poolvecs[:, i] = net(batch[0].cuda()).data.squeeze()
                if (i+1) % self.print_freq == 0 or (i+1) == len(idxs2images):
                    print('\r>>>> {}/{} done...'.format(i+1, len(idxs2images)), end='')
            print('')

            print('>> Searching for hard negatives...')
            # compute dot product scores and ranks on GPU
            scores = torch.mm(poolvecs.t(), qvecs)
            scores, ranks = torch.sort(scores, dim=0, descending=True)
            avg_ndist = torch.tensor(0).float().cuda()  # for statistics
            n_ndist = torch.tensor(0).float().cuda()  # for statistics
            # selection of negative examples
            self.nidxs = []
            for q in range(len(self.qidxs)):
                # do not use query cluster,
                # those images are potentially positive
                qcluster = self.clusters[self.qidxs[q]]
                clusters = [qcluster]
                nidxs = []
                r = 0
                # NOTE(review): r is not bounded; if the pool runs out of
                # unseen clusters before nnum negatives are found, the
                # ranks[r, q] lookup will fail.
                while len(nidxs) < self.nnum:
                    potential = idxs2images[ranks[r, q]]
                    # take at most one image from the same cluster
                    if not self.clusters[potential] in clusters:
                        nidxs.append(potential)
                        clusters.append(self.clusters[potential])
                        avg_ndist += torch.pow(qvecs[:,q]-poolvecs[:,ranks[r, q]]+1e-6, 2).sum(dim=0).sqrt()
                        n_ndist += 1
                    r += 1
                self.nidxs.append(nidxs)
            print('>>>> Average negative l2-distance: {:.2f}'.format(avg_ndist/n_ndist))
            print('>>>> Done')

        return (avg_ndist/n_ndist).item()  # return average negative l2-distance
|
Reference in New Issue
Block a user