import argparse import os import time import pickle import pdb import numpy as np import torch from torch.utils.model_zoo import load_url from torchvision import transforms from cirtorch.networks.imageretrievalnet import init_network, extract_vectors from cirtorch.datasets.testdataset import configdataset from cirtorch.utils.download import download_train, download_test from cirtorch.utils.evaluate import compute_map_and_print from cirtorch.utils.general import get_data_root, htime PRETRAINED = { 'rSfM120k-tl-resnet50-gem-w' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/retrieval-SfM-120k/rSfM120k-tl-resnet50-gem-w-97bf910.pth', 'rSfM120k-tl-resnet101-gem-w' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/retrieval-SfM-120k/rSfM120k-tl-resnet101-gem-w-a155e54.pth', 'rSfM120k-tl-resnet152-gem-w' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/retrieval-SfM-120k/rSfM120k-tl-resnet152-gem-w-f39cada.pth', 'gl18-tl-resnet50-gem-w' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/gl18/gl18-tl-resnet50-gem-w-83fdc30.pth', 'gl18-tl-resnet101-gem-w' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/gl18/gl18-tl-resnet101-gem-w-a4d43db.pth', 'gl18-tl-resnet152-gem-w' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/gl18/gl18-tl-resnet152-gem-w-21278d5.pth', } datasets_names = ['oxford5k', 'paris6k', 'roxford5k', 'rparis6k'] parser = argparse.ArgumentParser(description='PyTorch CNN Image Retrieval Testing End-to-End') # test options parser.add_argument('--network', '-n', metavar='NETWORK', help="network to be evaluated: " + " | ".join(PRETRAINED.keys())) parser.add_argument('--datasets', '-d', metavar='DATASETS', default='roxford5k,rparis6k', help="comma separated list of test datasets: " + " | ".join(datasets_names) + " (default: 'roxford5k,rparis6k')") parser.add_argument('--image-size', '-imsize', default=1024, type=int, metavar='N', help="maximum size of longer image side used for testing (default: 1024)") parser.add_argument('--multiscale', '-ms', metavar='MULTISCALE', default='[1]', help="use multiscale vectors for testing, " + " examples: '[1]' | '[1, 1/2**(1/2), 1/2]' | '[1, 2**(1/2), 1/2**(1/2)]' (default: '[1]')") # GPU ID parser.add_argument('--gpu-id', '-g', default='0', metavar='N', help="gpu id used for testing (default: '0')") def main(): args = parser.parse_args() # check if there are unknown datasets for dataset in args.datasets.split(','): if dataset not in datasets_names: raise ValueError('Unsupported or unknown dataset: {}!'.format(dataset)) # check if test dataset are downloaded # and download if they are not download_train(get_data_root()) download_test(get_data_root()) # setting up the visible GPU os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id # loading network # pretrained networks (downloaded automatically) print(">> Loading network:\n>>>> '{}'".format(args.network)) state = load_url(PRETRAINED[args.network], model_dir=os.path.join(get_data_root(), 'networks')) # state = torch.load(args.network) # parsing net params from meta # architecture, pooling, mean, std required # the rest has default values, in case that is doesnt exist net_params = {} net_params['architecture'] = state['meta']['architecture'] net_params['pooling'] = state['meta']['pooling'] net_params['local_whitening'] = state['meta'].get('local_whitening', False) net_params['regional'] = state['meta'].get('regional', False) net_params['whitening'] = state['meta'].get('whitening', False) net_params['mean'] = state['meta']['mean'] net_params['std'] = state['meta']['std'] net_params['pretrained'] = False # network initialization net = init_network(net_params) net.load_state_dict(state['state_dict']) print(">>>> loaded network: ") print(net.meta_repr()) # setting up the multi-scale parameters ms = list(eval(args.multiscale)) print(">>>> Evaluating scales: {}".format(ms)) # moving network to gpu and eval mode net.cuda() net.eval() # set up the transform normalize = transforms.Normalize( mean=net.meta['mean'], std=net.meta['std'] ) transform = transforms.Compose([ transforms.ToTensor(), normalize ]) # evaluate on test datasets datasets = args.datasets.split(',') for dataset in datasets: start = time.time() print('>> {}: Extracting...'.format(dataset)) # prepare config structure for the test dataset cfg = configdataset(dataset, os.path.join(get_data_root(), 'test')) images = [cfg['im_fname'](cfg,i) for i in range(cfg['n'])] qimages = [cfg['qim_fname'](cfg,i) for i in range(cfg['nq'])] try: bbxs = [tuple(cfg['gnd'][i]['bbx']) for i in range(cfg['nq'])] except: bbxs = None # for holidaysmanrot and copydays # extract database and query vectors print('>> {}: database images...'.format(dataset)) vecs = extract_vectors(net, images, args.image_size, transform, ms=ms) print('>> {}: query images...'.format(dataset)) qvecs = extract_vectors(net, qimages, args.image_size, transform, bbxs=bbxs, ms=ms) print('>> {}: Evaluating...'.format(dataset)) # convert to numpy vecs = vecs.numpy() qvecs = qvecs.numpy() # search, rank, and print scores = np.dot(vecs.T, qvecs) ranks = np.argsort(-scores, axis=0) compute_map_and_print(dataset, ranks, cfg['gnd']) print('>> {}: elapsed time: {}'.format(dataset, htime(time.time()-start))) if __name__ == '__main__': main()