# ieemoo-ai-searchv2/cirtorch/datasets/traindataset.py
import os
import pickle
import pdb
import numpy as np
import torch
import torch.utils.data as data
from cirtorch.datasets.datahelpers import default_loader, imresize, cid2filename
from cirtorch.datasets.genericdataset import ImagesFromList
from cirtorch.utils.general import get_data_root


class TuplesDataset(data.Dataset):
    """Data loader that loads training and validation tuples of
    Radenovic et al., ECCV 2016: CNN Image Retrieval Learns from BoW

    Args:
        name (string): dataset name: 'retrieval-SfM-120k'
        mode (string): 'train' or 'val' for training and validation parts of dataset
        imsize (int, Default: None): Defines the maximum size of the longer image side
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``
        loader (callable, optional): A function to load an image given its path.
        nnum (int, Default: 5): Number of negatives for a query image in a training tuple
        qsize (int, Default: 2000): Number of query images, i.e. number of (q,p,n1,...,nN) tuples, to be processed in one epoch
        poolsize (int, Default: 20000): Pool size for negative image re-mining

    Attributes:
        images (list): List of full filenames for each image
        clusters (list): List of cluster IDs per image
        qpool (list): List of all query image indexes
        ppool (list): List of positive image indexes, each corresponding to the query at the same position in qpool
        qidxs (list): List of qsize query image indexes to be processed in an epoch
        pidxs (list): List of qsize positive image indexes, each corresponding to the query at the same position in qidxs
        nidxs (list): List of qsize tuples of negative images
            Each nidxs tuple contains nnum images corresponding to the query image at the same position in qidxs

        Lists qidxs, pidxs, nidxs are refreshed by calling the ``create_epoch_tuples()`` method,
        i.e. new q-p pairs are picked and negative images are re-mined
    """

    def __init__(self, name, mode, imsize=None, nnum=5, qsize=2000, poolsize=20000, transform=None, loader=default_loader):

        if not (mode == 'train' or mode == 'val'):
            raise(RuntimeError("MODE should be either train or val, passed as string"))

        if name.startswith('retrieval-SfM'):
            # setting up paths
            # data_root = get_data_root()
            # db_root = os.path.join(data_root, 'train', name)
            # ims_root = os.path.join(db_root, 'ims')
            db_root = '/home/lc/project/Search_By_Image_Upgrade/cirtorch/IamgeRetrieval_dataset'
            ims_root = '/home/lc/project/Search_By_Image_Upgrade/cirtorch/IamgeRetrieval_dataset/train'

            # loading db
            db_fn = os.path.join(db_root, '{}.pkl'.format('train'))
            with open(db_fn, 'rb') as f:
                db = pickle.load(f)[mode]

            # setting fullpath for images (upstream cid2filename layout)
            self.images = [cid2filename(db['cids'][i], ims_root) for i in range(len(db['cids']))]

            # elif name.startswith('gl'):
            #     ## TODO: NOT IMPLEMENTED YET PROPERLY (WITH AUTOMATIC DOWNLOAD)
            #     # setting up paths
            #     db_root = '/mnt/fry2/users/datasets/landmarkscvprw18/recognition/'
            #     ims_root = os.path.join(db_root, 'images', 'train')
            #     # loading db
            #     db_fn = os.path.join(db_root, '{}.pkl'.format('train'))
            #     with open(db_fn, 'rb') as f:
            #         db = pickle.load(f)[mode]
            #     # setting fullpath for images

            # NOTE: this second assignment overrides the cid2filename paths above
            # and treats each cid as a plain filename under ims_root
            self.images = [os.path.join(ims_root, db['cids'][i]) for i in range(len(db['cids']))]
        else:
            raise(RuntimeError("Unknown dataset name!"))

        # initializing tuples dataset
        self.name = name
        self.mode = mode
        self.imsize = imsize
        self.clusters = db['cluster']
        self.qpool = db['qidxs']
        self.ppool = db['pidxs']

        ## If we want to keep only unique q-p pairs
        ## However, ordering of pairs will change, although that is not important
        # qpidxs = list(set([(self.qidxs[i], self.pidxs[i]) for i in range(len(self.qidxs))]))
        # self.qidxs = [qpidxs[i][0] for i in range(len(qpidxs))]
        # self.pidxs = [qpidxs[i][1] for i in range(len(qpidxs))]

        # size of training subset for an epoch
        self.nnum = nnum
        self.qsize = min(qsize, len(self.qpool))
        self.poolsize = min(poolsize, len(self.images))
        self.qidxs = None
        self.pidxs = None
        self.nidxs = None

        self.transform = transform
        self.loader = loader

        self.print_freq = 10
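
    # The training database pickle is expected (based on the keys accessed in
    # ``__init__`` above) to be a dict keyed by mode ('train'/'val'), each
    # entry containing at least:
    #   'cids'    - image identifiers / filenames
    #   'cluster' - cluster ID per image (images sharing a cluster are treated as potential positives)
    #   'qidxs'   - query image indexes
    #   'pidxs'   - positive image indexes, aligned with 'qidxs'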

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            images tuple (q,p,n1,...,nN): Loaded train/val tuple at index of self.qidxs
        """
        if self.__len__() == 0:
            raise(RuntimeError("List qidxs is empty. Run ``dataset.create_epoch_tuples(net)`` method to create subset for train/val!"))

        output = []
        # query image
        output.append(self.loader(self.images[self.qidxs[index]]))
        # positive image
        output.append(self.loader(self.images[self.pidxs[index]]))
        # negative images
        for i in range(len(self.nidxs[index])):
            output.append(self.loader(self.images[self.nidxs[index][i]]))

        if self.imsize is not None:
            output = [imresize(img, self.imsize) for img in output]

        if self.transform is not None:
            output = [self.transform(output[i]).unsqueeze_(0) for i in range(len(output))]

        target = torch.Tensor([-1, 1] + [0]*len(self.nidxs[index]))

        return output, target
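
    # Note (added for clarity): each item is a Python list of (optionally
    # transformed) images of possibly different sizes plus a target vector
    # (-1 for the query, 1 for the positive, 0 per negative), so a
    # tuple-aware collate_fn such as cirtorch's ``collate_tuples`` is needed
    # when wrapping this dataset in a DataLoader.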

    def __len__(self):
        # if not self.qidxs:
        #     return 0
        # return len(self.qidxs)
        return self.qsize

    def __repr__(self):
        fmt_str = self.__class__.__name__ + '\n'
        fmt_str += '    Name and mode: {} {}\n'.format(self.name, self.mode)
        fmt_str += '    Number of images: {}\n'.format(len(self.images))
        fmt_str += '    Number of training tuples: {}\n'.format(len(self.qpool))
        fmt_str += '    Number of negatives per tuple: {}\n'.format(self.nnum)
        fmt_str += '    Number of tuples processed in an epoch: {}\n'.format(self.qsize)
        fmt_str += '    Pool size for negative remining: {}\n'.format(self.poolsize)
        tmp = '    Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str

    def create_epoch_tuples(self, net):

        print('>> Creating tuples for an epoch of {}-{}...'.format(self.name, self.mode))
        print(">>>> used network: ")
        print(net.meta_repr())

        ## ------------------------
        ## SELECTING POSITIVE PAIRS
        ## ------------------------

        # draw qsize random queries for tuples
        idxs2qpool = torch.randperm(len(self.qpool))[:self.qsize]
        self.qidxs = [self.qpool[i] for i in idxs2qpool]
        self.pidxs = [self.ppool[i] for i in idxs2qpool]

        ## ------------------------
        ## SELECTING NEGATIVE PAIRS
        ## ------------------------

        # if nnum = 0 create dummy nidxs,
        # useful when only positives are used for training
        if self.nnum == 0:
            self.nidxs = [[] for _ in range(len(self.qidxs))]
            return 0

        # draw poolsize random images for the pool of negative images
        idxs2images = torch.randperm(len(self.images))[:self.poolsize]

        # prepare network
        net.cuda()
        net.eval()

        # no gradients computed, to reduce memory and increase speed
        with torch.no_grad():

            print('>> Extracting descriptors for query images...')
            # prepare query loader
            loader = torch.utils.data.DataLoader(
                ImagesFromList(root='', images=[self.images[i] for i in self.qidxs], imsize=self.imsize, transform=self.transform),
                batch_size=1, shuffle=False, num_workers=8, pin_memory=True
            )
            # extract query vectors
            qvecs = torch.zeros(net.meta['outputdim'], len(self.qidxs)).cuda()
            for i, input in enumerate(loader):
                qvecs[:, i] = net(input[0].cuda()).data.squeeze()
                if (i+1) % self.print_freq == 0 or (i+1) == len(self.qidxs):
                    print('\r>>>> {}/{} done...'.format(i+1, len(self.qidxs)), end='')
            print('')

            print('>> Extracting descriptors for negative pool...')
            # prepare negative pool data loader
            loader = torch.utils.data.DataLoader(
                ImagesFromList(root='', images=[self.images[i] for i in idxs2images], imsize=self.imsize, transform=self.transform),
                batch_size=1, shuffle=False, num_workers=8, pin_memory=True
            )
            # extract negative pool vectors
            poolvecs = torch.zeros(net.meta['outputdim'], len(idxs2images)).cuda()
            for i, input in enumerate(loader):
                poolvecs[:, i] = net(input[0].cuda()).data.squeeze()
                if (i+1) % self.print_freq == 0 or (i+1) == len(idxs2images):
                    print('\r>>>> {}/{} done...'.format(i+1, len(idxs2images)), end='')
            print('')

            print('>> Searching for hard negatives...')
            # compute dot-product scores and ranks on GPU;
            # a higher score means more similar, so the top-ranked pool images
            # are the hardest negative candidates for each query
            scores = torch.mm(poolvecs.t(), qvecs)
            scores, ranks = torch.sort(scores, dim=0, descending=True)

            avg_ndist = torch.tensor(0).float().cuda()  # for statistics
            n_ndist = torch.tensor(0).float().cuda()  # for statistics

            # selection of negative examples
            self.nidxs = []
            for q in range(len(self.qidxs)):
                # do not use the query cluster,
                # those images are potentially positive
                qcluster = self.clusters[self.qidxs[q]]
                clusters = [qcluster]
                nidxs = []
                r = 0
                while len(nidxs) < self.nnum:
                    potential = idxs2images[ranks[r, q]]
                    # take at most one image from the same cluster
                    if not self.clusters[potential] in clusters:
                        nidxs.append(potential)
                        clusters.append(self.clusters[potential])
                        avg_ndist += torch.pow(qvecs[:, q] - poolvecs[:, ranks[r, q]] + 1e-6, 2).sum(dim=0).sqrt()
                        n_ndist += 1
                    r += 1
                self.nidxs.append(nidxs)
            print('>>>> Average negative l2-distance: {:.2f}'.format(avg_ndist/n_ndist))
            print('>>>> Done')

        return (avg_ndist/n_ndist).item()  # return average negative l2-distance
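

# ---------------------------------------------------------------------------
# Illustrative smoke test (a sketch, not part of the original module). It
# assumes the hard-coded db_root above exists on disk; the commented part
# additionally assumes a GPU and a trained cirtorch model (``init_network``
# is the upstream cirtorch helper, referenced here as an assumption).
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import torchvision.transforms as transforms

    # standard ImageNet normalization, as used by cirtorch training scripts
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    transform = transforms.Compose([transforms.ToTensor(), normalize])

    dataset = TuplesDataset(name='retrieval-SfM-120k', mode='train',
                            imsize=362, nnum=5, qsize=200, poolsize=2000,
                            transform=transform)
    print(dataset)

    # Hard-negative mining requires a network exposing meta['outputdim'], e.g.:
    #   from cirtorch.networks.imageretrievalnet import init_network
    #   net = init_network({'architecture': 'resnet50', 'pooling': 'gem'})
    #   dataset.create_epoch_tuples(net)
    #   images, target = dataset[0]   # list of 1xCxHxW tensors + target vector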