Files
ieemoo-ai-searchv2/cirtorch/examples/test_e2e.py
2022-11-22 15:32:06 +08:00

146 lines
5.9 KiB
Python
Executable File

import argparse
import os
import time
import pickle
import pdb
import numpy as np
import torch
from torch.utils.model_zoo import load_url
from torchvision import transforms
from cirtorch.networks.imageretrievalnet import init_network, extract_vectors
from cirtorch.datasets.testdataset import configdataset
from cirtorch.utils.download import download_train, download_test
from cirtorch.utils.evaluate import compute_map_and_print
from cirtorch.utils.general import get_data_root, htime
# Download URLs of the pretrained networks, keyed by the name given to --network.
PRETRAINED = {
    # checkpoints hosted under retrieval-SfM-120k/
    'rSfM120k-tl-resnet50-gem-w': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/retrieval-SfM-120k/rSfM120k-tl-resnet50-gem-w-97bf910.pth',
    'rSfM120k-tl-resnet101-gem-w': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/retrieval-SfM-120k/rSfM120k-tl-resnet101-gem-w-a155e54.pth',
    'rSfM120k-tl-resnet152-gem-w': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/retrieval-SfM-120k/rSfM120k-tl-resnet152-gem-w-f39cada.pth',
    # checkpoints hosted under gl18/
    'gl18-tl-resnet50-gem-w': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/gl18/gl18-tl-resnet50-gem-w-83fdc30.pth',
    'gl18-tl-resnet101-gem-w': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/gl18/gl18-tl-resnet101-gem-w-a4d43db.pth',
    'gl18-tl-resnet152-gem-w': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/gl18/gl18-tl-resnet152-gem-w-21278d5.pth',
}

# Dataset names that --datasets may contain.
datasets_names = [
    'oxford5k',
    'paris6k',
    'roxford5k',
    'rparis6k',
]


def _build_parser():
    """Create the command-line parser for the end-to-end test script."""
    network_choices = " | ".join(PRETRAINED)
    dataset_choices = " | ".join(datasets_names)

    cli = argparse.ArgumentParser(
        description='PyTorch CNN Image Retrieval Testing End-to-End',
    )

    # test options
    cli.add_argument(
        '--network', '-n',
        metavar='NETWORK',
        help="network to be evaluated: " + network_choices,
    )
    cli.add_argument(
        '--datasets', '-d',
        default='roxford5k,rparis6k',
        metavar='DATASETS',
        help=("comma separated list of test datasets: " + dataset_choices +
              " (default: 'roxford5k,rparis6k')"),
    )
    cli.add_argument(
        '--image-size', '-imsize',
        type=int,
        default=1024,
        metavar='N',
        help="maximum size of longer image side used for testing (default: 1024)",
    )
    cli.add_argument(
        '--multiscale', '-ms',
        default='[1]',
        metavar='MULTISCALE',
        help=("use multiscale vectors for testing, " +
              " examples: '[1]' | '[1, 1/2**(1/2), 1/2]' | '[1, 2**(1/2), 1/2**(1/2)]' (default: '[1]')"),
    )

    # GPU ID
    cli.add_argument(
        '--gpu-id', '-g',
        default='0',
        metavar='N',
        help="gpu id used for testing (default: '0')",
    )
    return cli


# Module-level parser consumed by main().
parser = _build_parser()
def main():
    """Run the end-to-end retrieval evaluation selected on the command line.

    Parses the options, validates the network and dataset names, makes sure
    the data is present (download_train / download_test), loads the
    pretrained network named by --network, then for every requested dataset
    extracts database and query vectors, ranks the database by dot-product
    score and reports the result through compute_map_and_print.

    Raises:
        ValueError: if --network is missing or is not a key of PRETRAINED,
            or if --datasets contains a name that is not in datasets_names.
    """
    args = parser.parse_args()

    # Reject a missing/unknown network before anything is downloaded;
    # otherwise the failure only shows up as a KeyError after the downloads.
    if args.network not in PRETRAINED:
        raise ValueError('Unsupported or unknown network: {}!'.format(args.network))

    # check if there are unknown datasets
    # (split once; surrounding whitespace such as 'a, b' is tolerated)
    datasets = [name.strip() for name in args.datasets.split(',')]
    for dataset in datasets:
        if dataset not in datasets_names:
            raise ValueError('Unsupported or unknown dataset: {}!'.format(dataset))

    data_root = get_data_root()

    # check if test dataset are downloaded
    # and download if they are not
    download_train(data_root)
    download_test(data_root)

    # setting up the visible GPU (done before the network is moved to cuda)
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id

    # loading network
    # pretrained networks (downloaded automatically)
    print(">> Loading network:\n>>>> '{}'".format(args.network))
    state = load_url(PRETRAINED[args.network], model_dir=os.path.join(data_root, 'networks'))

    # parsing net params from meta
    # architecture, pooling, mean, std required
    # the rest has default values, in case they do not exist
    meta = state['meta']
    net_params = {
        'architecture': meta['architecture'],
        'pooling': meta['pooling'],
        'local_whitening': meta.get('local_whitening', False),
        'regional': meta.get('regional', False),
        'whitening': meta.get('whitening', False),
        'mean': meta['mean'],
        'std': meta['std'],
        # weights come from the checkpoint's state_dict loaded below
        'pretrained': False,
    }

    # network initialization
    net = init_network(net_params)
    net.load_state_dict(state['state_dict'])
    print(">>>> loaded network: ")
    print(net.meta_repr())

    # setting up the multi-scale parameters
    # NOTE(security): eval() executes arbitrary Python taken from the command
    # line. It is kept because the documented examples contain arithmetic
    # (e.g. '1/2**(1/2)') that ast.literal_eval rejects; never feed this
    # option from an untrusted source.
    ms = list(eval(args.multiscale))
    print(">>>> Evaluating scales: {}".format(ms))

    # moving network to gpu and eval mode
    net.cuda()
    net.eval()

    # set up the transform: tensor conversion followed by the normalization
    # statistics stored with the network
    normalize = transforms.Normalize(
        mean=net.meta['mean'],
        std=net.meta['std']
    )
    transform = transforms.Compose([
        transforms.ToTensor(),
        normalize
    ])

    # evaluate on test datasets
    for dataset in datasets:
        start = time.time()
        print('>> {}: Extracting...'.format(dataset))

        # prepare config structure for the test dataset
        cfg = configdataset(dataset, os.path.join(data_root, 'test'))
        images = [cfg['im_fname'](cfg, i) for i in range(cfg['nq'] * 0 + cfg['n'])]
        qimages = [cfg['qim_fname'](cfg, i) for i in range(cfg['nq'])]

        # Query bounding boxes are optional: fall back to None only when the
        # ground truth lacks them, without swallowing KeyboardInterrupt or
        # SystemExit as a bare except would.
        try:
            bbxs = [tuple(cfg['gnd'][i]['bbx']) for i in range(cfg['nq'])]
        except (KeyError, IndexError, TypeError):
            bbxs = None  # for holidaysmanrot and copydays

        # extract database and query vectors
        print('>> {}: database images...'.format(dataset))
        vecs = extract_vectors(net, images, args.image_size, transform, ms=ms)
        print('>> {}: query images...'.format(dataset))
        qvecs = extract_vectors(net, qimages, args.image_size, transform, bbxs=bbxs, ms=ms)

        print('>> {}: Evaluating...'.format(dataset))

        # convert to numpy
        vecs = vecs.numpy()
        qvecs = qvecs.numpy()

        # search, rank, and print
        # NOTE(review): presumably one descriptor per column, giving a
        # (database x query) score matrix -- confirm against extract_vectors.
        scores = np.dot(vecs.T, qvecs)
        # negate so that argsort orders each column by descending score
        ranks = np.argsort(-scores, axis=0)
        compute_map_and_print(dataset, ranks, cfg['gnd'])

        print('>> {}: elapsed time: {}'.format(dataset, htime(time.time()-start)))
# Run the evaluation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()