import argparse import os import time import pickle import pdb import numpy as np import torch from torch.utils.model_zoo import load_url from torchvision import transforms from cirtorch.networks.imageretrievalnet import init_network, extract_vectors from cirtorch.datasets.datahelpers import cid2filename from cirtorch.datasets.testdataset import configdataset from cirtorch.utils.download import download_train, download_test from cirtorch.utils.whiten import whitenlearn, whitenapply from cirtorch.utils.evaluate import compute_map_and_print from cirtorch.utils.general import get_data_root, htime PRETRAINED = { 'retrievalSfM120k-vgg16-gem' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/retrieval-SfM-120k/retrievalSfM120k-vgg16-gem-b4dcdc6.pth', 'retrievalSfM120k-resnet101-gem' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/retrieval-SfM-120k/retrievalSfM120k-resnet101-gem-b80fb85.pth', # new networks with whitening learned end-to-end 'rSfM120k-tl-resnet50-gem-w' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/retrieval-SfM-120k/rSfM120k-tl-resnet50-gem-w-97bf910.pth', 'rSfM120k-tl-resnet101-gem-w' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/retrieval-SfM-120k/rSfM120k-tl-resnet101-gem-w-a155e54.pth', 'rSfM120k-tl-resnet152-gem-w' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/retrieval-SfM-120k/rSfM120k-tl-resnet152-gem-w-f39cada.pth', 'gl18-tl-resnet50-gem-w' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/gl18/gl18-tl-resnet50-gem-w-83fdc30.pth', 'gl18-tl-resnet101-gem-w' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/gl18/gl18-tl-resnet101-gem-w-a4d43db.pth', 'gl18-tl-resnet152-gem-w' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/gl18/gl18-tl-resnet152-gem-w-21278d5.pth', } datasets_names = ['oxford5k', 'paris6k', 'roxford5k', 'rparis6k'] whitening_names = ['retrieval-SfM-30k', 'retrieval-SfM-120k'] parser = argparse.ArgumentParser(description='PyTorch CNN Image Retrieval Testing') # network group = parser.add_mutually_exclusive_group(required=True) group.add_argument('--network-path', '-npath', metavar='NETWORK', help="pretrained network or network path (destination where network is saved)") group.add_argument('--network-offtheshelf', '-noff', metavar='NETWORK', help="off-the-shelf network, in the format 'ARCHITECTURE-POOLING' or 'ARCHITECTURE-POOLING-{reg-lwhiten-whiten}'," + " examples: 'resnet101-gem' | 'resnet101-gem-reg' | 'resnet101-gem-whiten' | 'resnet101-gem-lwhiten' | 'resnet101-gem-reg-whiten'") # test options parser.add_argument('--datasets', '-d', metavar='DATASETS', default='oxford5k,paris6k', help="comma separated list of test datasets: " + " | ".join(datasets_names) + " (default: 'oxford5k,paris6k')") parser.add_argument('--image-size', '-imsize', default=1024, type=int, metavar='N', help="maximum size of longer image side used for testing (default: 1024)") parser.add_argument('--multiscale', '-ms', metavar='MULTISCALE', default='[1]', help="use multiscale vectors for testing, " + " examples: '[1]' | '[1, 1/2**(1/2), 1/2]' | '[1, 2**(1/2), 1/2**(1/2)]' (default: '[1]')") parser.add_argument('--whitening', '-w', metavar='WHITENING', default=None, choices=whitening_names, help="dataset used to learn whitening for testing: " + " | ".join(whitening_names) + " (default: None)") # GPU ID parser.add_argument('--gpu-id', '-g', default='0', metavar='N', help="gpu id used for testing (default: '0')") def main(): args = parser.parse_args() # check if there are unknown datasets for dataset in args.datasets.split(','): if dataset not in datasets_names: raise ValueError('Unsupported or unknown dataset: {}!'.format(dataset)) # check if test dataset are downloaded # and download if they are not download_train(get_data_root()) download_test(get_data_root()) # setting up the visible GPU os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id # loading network from path if args.network_path is not None: print(">> Loading network:\n>>>> '{}'".format(args.network_path)) if args.network_path in PRETRAINED: # pretrained networks (downloaded automatically) state = load_url(PRETRAINED[args.network_path], model_dir=os.path.join(get_data_root(), 'networks')) else: # fine-tuned network from path state = torch.load(args.network_path) # parsing net params from meta # architecture, pooling, mean, std required # the rest has default values, in case that is doesnt exist net_params = {} net_params['architecture'] = state['meta']['architecture'] net_params['pooling'] = state['meta']['pooling'] net_params['local_whitening'] = state['meta'].get('local_whitening', False) net_params['regional'] = state['meta'].get('regional', False) net_params['whitening'] = state['meta'].get('whitening', False) net_params['mean'] = state['meta']['mean'] net_params['std'] = state['meta']['std'] net_params['pretrained'] = False # load network net = init_network(net_params) net.load_state_dict(state['state_dict']) # if whitening is precomputed if 'Lw' in state['meta']: net.meta['Lw'] = state['meta']['Lw'] print(">>>> loaded network: ") print(net.meta_repr()) # loading offtheshelf network elif args.network_offtheshelf is not None: # parse off-the-shelf parameters offtheshelf = args.network_offtheshelf.split('-') net_params = {} net_params['architecture'] = offtheshelf[0] net_params['pooling'] = offtheshelf[1] net_params['local_whitening'] = 'lwhiten' in offtheshelf[2:] net_params['regional'] = 'reg' in offtheshelf[2:] net_params['whitening'] = 'whiten' in offtheshelf[2:] net_params['pretrained'] = True # load off-the-shelf network print(">> Loading off-the-shelf network:\n>>>> '{}'".format(args.network_offtheshelf)) net = init_network(net_params) print(">>>> loaded network: ") print(net.meta_repr()) # setting up the multi-scale parameters ms = list(eval(args.multiscale)) if len(ms)>1 and net.meta['pooling'] == 'gem' and not net.meta['regional'] and not net.meta['whitening']: msp = net.pool.p.item() print(">> Set-up multiscale:") print(">>>> ms: {}".format(ms)) print(">>>> msp: {}".format(msp)) else: msp = 1 # moving network to gpu and eval mode net.cuda() net.eval() # set up the transform normalize = transforms.Normalize( mean=net.meta['mean'], std=net.meta['std'] ) transform = transforms.Compose([ transforms.ToTensor(), normalize ]) # compute whitening if args.whitening is not None: start = time.time() if 'Lw' in net.meta and args.whitening in net.meta['Lw']: print('>> {}: Whitening is precomputed, loading it...'.format(args.whitening)) if len(ms)>1: Lw = net.meta['Lw'][args.whitening]['ms'] else: Lw = net.meta['Lw'][args.whitening]['ss'] else: # if we evaluate networks from path we should save/load whitening # not to compute it every time if args.network_path is not None: whiten_fn = args.network_path + '_{}_whiten'.format(args.whitening) if len(ms) > 1: whiten_fn += '_ms' whiten_fn += '.pth' else: whiten_fn = None if whiten_fn is not None and os.path.isfile(whiten_fn): print('>> {}: Whitening is precomputed, loading it...'.format(args.whitening)) Lw = torch.load(whiten_fn) else: print('>> {}: Learning whitening...'.format(args.whitening)) # loading db db_root = os.path.join(get_data_root(), 'train', args.whitening) ims_root = os.path.join(db_root, 'ims') db_fn = os.path.join(db_root, '{}-whiten.pkl'.format(args.whitening)) with open(db_fn, 'rb') as f: db = pickle.load(f) images = [cid2filename(db['cids'][i], ims_root) for i in range(len(db['cids']))] # extract whitening vectors print('>> {}: Extracting...'.format(args.whitening)) wvecs = extract_vectors(net, images, args.image_size, transform, ms=ms, msp=msp) # learning whitening print('>> {}: Learning...'.format(args.whitening)) wvecs = wvecs.numpy() m, P = whitenlearn(wvecs, db['qidxs'], db['pidxs']) Lw = {'m': m, 'P': P} # saving whitening if whiten_fn exists if whiten_fn is not None: print('>> {}: Saving to {}...'.format(args.whitening, whiten_fn)) torch.save(Lw, whiten_fn) print('>> {}: elapsed time: {}'.format(args.whitening, htime(time.time()-start))) else: Lw = None # evaluate on test datasets datasets = args.datasets.split(',') for dataset in datasets: start = time.time() print('>> {}: Extracting...'.format(dataset)) # prepare config structure for the test dataset cfg = configdataset(dataset, os.path.join(get_data_root(), 'test')) images = [cfg['im_fname'](cfg,i) for i in range(cfg['n'])] qimages = [cfg['qim_fname'](cfg,i) for i in range(cfg['nq'])] try: bbxs = [tuple(cfg['gnd'][i]['bbx']) for i in range(cfg['nq'])] except: bbxs = None # for holidaysmanrot and copydays # extract database and query vectors print('>> {}: database images...'.format(dataset)) vecs = extract_vectors(net, images, args.image_size, transform, ms=ms, msp=msp) print('>> {}: query images...'.format(dataset)) qvecs = extract_vectors(net, qimages, args.image_size, transform, bbxs=bbxs, ms=ms, msp=msp) print('>> {}: Evaluating...'.format(dataset)) # convert to numpy vecs = vecs.numpy() qvecs = qvecs.numpy() # search, rank, and print scores = np.dot(vecs.T, qvecs) ranks = np.argsort(-scores, axis=0) compute_map_and_print(dataset, ranks, cfg['gnd']) if Lw is not None: # whiten the vectors vecs_lw = whitenapply(vecs, Lw['m'], Lw['P']) qvecs_lw = whitenapply(qvecs, Lw['m'], Lw['P']) # search, rank, and print scores = np.dot(vecs_lw.T, qvecs_lw) ranks = np.argsort(-scores, axis=0) compute_map_and_print(dataset + ' + whiten', ranks, cfg['gnd']) print('>> {}: elapsed time: {}'.format(dataset, htime(time.time()-start))) if __name__ == '__main__': main()