266 lines
11 KiB
Python
Executable File
266 lines
11 KiB
Python
Executable File
import argparse
|
|
import os
|
|
import time
|
|
import pickle
|
|
import pdb
|
|
|
|
import numpy as np
|
|
|
|
import torch
|
|
from torch.utils.model_zoo import load_url
|
|
from torchvision import transforms
|
|
|
|
from cirtorch.networks.imageretrievalnet import init_network, extract_vectors
|
|
from cirtorch.datasets.datahelpers import cid2filename
|
|
from cirtorch.datasets.testdataset import configdataset
|
|
from cirtorch.utils.download import download_train, download_test
|
|
from cirtorch.utils.whiten import whitenlearn, whitenapply
|
|
from cirtorch.utils.evaluate import compute_map_and_print
|
|
from cirtorch.utils.general import get_data_root, htime
|
|
|
|
# Networks that can be requested by name through --network-path.
# Maps the network name to the URL of its checkpoint; main() fetches the
# checkpoint with load_url when the given --network-path is one of these keys.
PRETRAINED = {
    'retrievalSfM120k-vgg16-gem'        : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/retrieval-SfM-120k/retrievalSfM120k-vgg16-gem-b4dcdc6.pth',
    'retrievalSfM120k-resnet101-gem'    : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/retrieval-SfM-120k/retrievalSfM120k-resnet101-gem-b80fb85.pth',
    # new networks with whitening learned end-to-end
    'rSfM120k-tl-resnet50-gem-w'        : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/retrieval-SfM-120k/rSfM120k-tl-resnet50-gem-w-97bf910.pth',
    'rSfM120k-tl-resnet101-gem-w'       : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/retrieval-SfM-120k/rSfM120k-tl-resnet101-gem-w-a155e54.pth',
    'rSfM120k-tl-resnet152-gem-w'       : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/retrieval-SfM-120k/rSfM120k-tl-resnet152-gem-w-f39cada.pth',
    'gl18-tl-resnet50-gem-w'            : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/gl18/gl18-tl-resnet50-gem-w-83fdc30.pth',
    'gl18-tl-resnet101-gem-w'           : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/gl18/gl18-tl-resnet101-gem-w-a4d43db.pth',
    'gl18-tl-resnet152-gem-w'           : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/gl18/gl18-tl-resnet152-gem-w-21278d5.pth',
}
# Test datasets accepted by --datasets (validated in main()).
datasets_names = ['oxford5k', 'paris6k', 'roxford5k', 'rparis6k']
# Training sets accepted by --whitening (enforced by argparse `choices`).
whitening_names = ['retrieval-SfM-30k', 'retrieval-SfM-120k']

# Command-line interface, built at import time and consumed by main().
parser = argparse.ArgumentParser(description='PyTorch CNN Image Retrieval Testing')

# network
# exactly one of --network-path / --network-offtheshelf has to be given
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument('--network-path', '-npath', metavar='NETWORK',
                    help="pretrained network or network path (destination where network is saved)")
group.add_argument('--network-offtheshelf', '-noff', metavar='NETWORK',
                    help="off-the-shelf network, in the format 'ARCHITECTURE-POOLING' or 'ARCHITECTURE-POOLING-{reg-lwhiten-whiten}'," +
                        " examples: 'resnet101-gem' | 'resnet101-gem-reg' | 'resnet101-gem-whiten' | 'resnet101-gem-lwhiten' | 'resnet101-gem-reg-whiten'")

# test options
parser.add_argument('--datasets', '-d', metavar='DATASETS', default='oxford5k,paris6k',
                    help="comma separated list of test datasets: " +
                        " | ".join(datasets_names) +
                        " (default: 'oxford5k,paris6k')")
parser.add_argument('--image-size', '-imsize', default=1024, type=int, metavar='N',
                    help="maximum size of longer image side used for testing (default: 1024)")
# kept as a string here; main() turns it into a list of scales
parser.add_argument('--multiscale', '-ms', metavar='MULTISCALE', default='[1]',
                    help="use multiscale vectors for testing, " +
                        " examples: '[1]' | '[1, 1/2**(1/2), 1/2]' | '[1, 2**(1/2), 1/2**(1/2)]' (default: '[1]')")
parser.add_argument('--whitening', '-w', metavar='WHITENING', default=None, choices=whitening_names,
                    help="dataset used to learn whitening for testing: " +
                        " | ".join(whitening_names) +
                        " (default: None)")

# GPU ID
# string on purpose: main() writes it verbatim into CUDA_VISIBLE_DEVICES
parser.add_argument('--gpu-id', '-g', default='0', metavar='N',
                    help="gpu id used for testing (default: '0')")
def main():
    """Evaluate an image-retrieval network on the requested test datasets.

    Parses the command line, loads the network selected with
    ``--network-path`` or ``--network-offtheshelf``, optionally obtains a
    whitening transform (``--whitening``) and then, for every dataset,
    extracts database/query vectors, ranks them by dot-product score and
    reports the result through ``compute_map_and_print`` (once for the raw
    vectors and, when whitening is available, once for the whitened ones).

    Raises:
        ValueError: if an unknown dataset is requested, or if the
            off-the-shelf network name is not 'ARCHITECTURE-POOLING[-...]'.
    """
    args = parser.parse_args()

    # check if there are unknown datasets (fail before any download starts)
    datasets = args.datasets.split(',')
    for dataset in datasets:
        if dataset not in datasets_names:
            raise ValueError('Unsupported or unknown dataset: {}!'.format(dataset))

    # check if train/test datasets are downloaded
    # and download if they are not
    download_train(get_data_root())
    download_test(get_data_root())

    # setting up the visible GPU
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id

    net = _load_network(args)

    # setting up the multi-scale parameters
    ms, msp = _setup_multiscale(args, net)

    # moving network to gpu and eval mode
    net.cuda()
    net.eval()

    # set up the transform
    normalize = transforms.Normalize(
        mean=net.meta['mean'],
        std=net.meta['std']
    )
    transform = transforms.Compose([
        transforms.ToTensor(),
        normalize
    ])

    # compute whitening (None when --whitening was not given)
    Lw = _get_whitening(args, net, transform, ms, msp)

    # evaluate on test datasets
    for dataset in datasets:
        _evaluate_dataset(dataset, args, net, transform, ms, msp, Lw)


def _load_network(args):
    """Build the network selected on the command line and print its meta."""
    # loading network from path
    if args.network_path is not None:

        print(">> Loading network:\n>>>> '{}'".format(args.network_path))
        if args.network_path in PRETRAINED:
            # pretrained networks (downloaded automatically)
            state = load_url(PRETRAINED[args.network_path], model_dir=os.path.join(get_data_root(), 'networks'))
        else:
            # fine-tuned network from path
            state = torch.load(args.network_path)

        # parsing net params from meta
        # architecture, pooling, mean, std required
        # the rest has default values, in case they do not exist
        meta = state['meta']
        net_params = {}
        net_params['architecture'] = meta['architecture']
        net_params['pooling'] = meta['pooling']
        net_params['local_whitening'] = meta.get('local_whitening', False)
        net_params['regional'] = meta.get('regional', False)
        net_params['whitening'] = meta.get('whitening', False)
        net_params['mean'] = meta['mean']
        net_params['std'] = meta['std']
        net_params['pretrained'] = False

        # load network
        net = init_network(net_params)
        net.load_state_dict(state['state_dict'])

        # if whitening is precomputed
        if 'Lw' in meta:
            net.meta['Lw'] = meta['Lw']

        print(">>>> loaded network: ")
        print(net.meta_repr())
        return net

    # loading offtheshelf network
    # (the parser's required mutually exclusive group guarantees that
    # --network-offtheshelf is set whenever --network-path is not)

    # parse off-the-shelf parameters
    offtheshelf = args.network_offtheshelf.split('-')
    if len(offtheshelf) < 2:
        # without this check the pooling lookup below fails with a bare IndexError
        raise ValueError("Off-the-shelf network must be given as 'ARCHITECTURE-POOLING[-...]', "
                         "got: '{}'".format(args.network_offtheshelf))
    options = offtheshelf[2:]
    net_params = {}
    net_params['architecture'] = offtheshelf[0]
    net_params['pooling'] = offtheshelf[1]
    net_params['local_whitening'] = 'lwhiten' in options
    net_params['regional'] = 'reg' in options
    net_params['whitening'] = 'whiten' in options
    net_params['pretrained'] = True

    # load off-the-shelf network
    print(">> Loading off-the-shelf network:\n>>>> '{}'".format(args.network_offtheshelf))
    net = init_network(net_params)
    print(">>>> loaded network: ")
    print(net.meta_repr())
    return net


def _setup_multiscale(args, net):
    """Return ``(ms, msp)``: the list of scales and the multi-scale power."""
    # NOTE(security): eval() executes arbitrary Python from the command line.
    # It is kept because the documented examples are arithmetic expressions
    # such as '[1, 1/2**(1/2), 1/2]', which ast.literal_eval cannot parse.
    # Never feed this option from an untrusted source.
    ms = list(eval(args.multiscale))
    if len(ms) > 1 and net.meta['pooling'] == 'gem' and not net.meta['regional'] and not net.meta['whitening']:
        # reuse the exponent of the network's GeM pooling layer
        msp = net.pool.p.item()
        print(">> Set-up multiscale:")
        print(">>>> ms: {}".format(ms))
        print(">>>> msp: {}".format(msp))
    else:
        msp = 1
    return ms, msp


def _get_whitening(args, net, transform, ms, msp):
    """Return the whitening dict for ``--whitening``, or None if not requested."""
    if args.whitening is None:
        return None

    start = time.time()

    if 'Lw' in net.meta and args.whitening in net.meta['Lw']:

        print('>> {}: Whitening is precomputed, loading it...'.format(args.whitening))

        # the network stores separate whitening for multi- and single-scale
        if len(ms) > 1:
            Lw = net.meta['Lw'][args.whitening]['ms']
        else:
            Lw = net.meta['Lw'][args.whitening]['ss']

    else:

        # if we evaluate networks from path we should save/load whitening
        # not to compute it every time
        if args.network_path is not None:
            whiten_fn = args.network_path + '_{}_whiten'.format(args.whitening)
            if len(ms) > 1:
                whiten_fn += '_ms'
            whiten_fn += '.pth'
        else:
            whiten_fn = None

        if whiten_fn is not None and os.path.isfile(whiten_fn):
            print('>> {}: Whitening is precomputed, loading it...'.format(args.whitening))
            Lw = torch.load(whiten_fn)

        else:
            Lw = _learn_whitening(args, net, transform, ms, msp)

            # saving whitening if whiten_fn exists
            if whiten_fn is not None:
                print('>> {}: Saving to {}...'.format(args.whitening, whiten_fn))
                torch.save(Lw, whiten_fn)

    print('>> {}: elapsed time: {}'.format(args.whitening, htime(time.time()-start)))
    return Lw


def _learn_whitening(args, net, transform, ms, msp):
    """Learn whitening on the ``--whitening`` training set; return ``{'m', 'P'}``."""
    print('>> {}: Learning whitening...'.format(args.whitening))

    # loading db
    db_root = os.path.join(get_data_root(), 'train', args.whitening)
    ims_root = os.path.join(db_root, 'ims')
    db_fn = os.path.join(db_root, '{}-whiten.pkl'.format(args.whitening))
    # NOTE(security): pickle.load can run arbitrary code; the file is read
    # from the local data root and must come from a trusted source.
    with open(db_fn, 'rb') as f:
        db = pickle.load(f)
    images = [cid2filename(cid, ims_root) for cid in db['cids']]

    # extract whitening vectors
    print('>> {}: Extracting...'.format(args.whitening))
    wvecs = extract_vectors(net, images, args.image_size, transform, ms=ms, msp=msp)

    # learning whitening
    print('>> {}: Learning...'.format(args.whitening))
    wvecs = wvecs.numpy()
    m, P = whitenlearn(wvecs, db['qidxs'], db['pidxs'])
    return {'m': m, 'P': P}


def _evaluate_dataset(dataset, args, net, transform, ms, msp, Lw):
    """Extract, rank and report results for one test dataset."""
    start = time.time()

    print('>> {}: Extracting...'.format(dataset))

    # prepare config structure for the test dataset
    cfg = configdataset(dataset, os.path.join(get_data_root(), 'test'))
    images = [cfg['im_fname'](cfg, i) for i in range(cfg['nq'] * 0 + cfg['n'])]
    qimages = [cfg['qim_fname'](cfg, i) for i in range(cfg['nq'])]
    try:
        bbxs = [tuple(cfg['gnd'][i]['bbx']) for i in range(cfg['nq'])]
    except (KeyError, IndexError, TypeError):
        # ground truth without query bounding boxes
        bbxs = None  # for holidaysmanrot and copydays

    # extract database and query vectors
    print('>> {}: database images...'.format(dataset))
    vecs = extract_vectors(net, images, args.image_size, transform, ms=ms, msp=msp)
    print('>> {}: query images...'.format(dataset))
    qvecs = extract_vectors(net, qimages, args.image_size, transform, bbxs=bbxs, ms=ms, msp=msp)

    print('>> {}: Evaluating...'.format(dataset))

    # convert to numpy
    vecs = vecs.numpy()
    qvecs = qvecs.numpy()

    # search, rank, and print
    # (negated scores => argsort yields descending similarity along axis 0;
    # presumably one descriptor per column, consistent with vecs.T @ qvecs)
    scores = np.dot(vecs.T, qvecs)
    ranks = np.argsort(-scores, axis=0)
    compute_map_and_print(dataset, ranks, cfg['gnd'])

    if Lw is not None:
        # whiten the vectors
        vecs_lw = whitenapply(vecs, Lw['m'], Lw['P'])
        qvecs_lw = whitenapply(qvecs, Lw['m'], Lw['P'])

        # search, rank, and print
        scores = np.dot(vecs_lw.T, qvecs_lw)
        ranks = np.argsort(-scores, axis=0)
        compute_map_and_print(dataset + ' + whiten', ranks, cfg['gnd'])

    print('>> {}: elapsed time: {}'.format(dataset, htime(time.time()-start)))
# Run the evaluation only when executed as a script, not when imported.
if __name__ == '__main__':
    main()