first push
This commit is contained in:
145
cirtorch/examples/test_e2e.py
Executable file
145
cirtorch/examples/test_e2e.py
Executable file
@ -0,0 +1,145 @@
|
||||
import argparse
|
||||
import os
|
||||
import time
|
||||
import pickle
|
||||
import pdb
|
||||
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
from torch.utils.model_zoo import load_url
|
||||
from torchvision import transforms
|
||||
|
||||
from cirtorch.networks.imageretrievalnet import init_network, extract_vectors
|
||||
from cirtorch.datasets.testdataset import configdataset
|
||||
from cirtorch.utils.download import download_train, download_test
|
||||
from cirtorch.utils.evaluate import compute_map_and_print
|
||||
from cirtorch.utils.general import get_data_root, htime
|
||||
|
||||
PRETRAINED = {
|
||||
'rSfM120k-tl-resnet50-gem-w' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/retrieval-SfM-120k/rSfM120k-tl-resnet50-gem-w-97bf910.pth',
|
||||
'rSfM120k-tl-resnet101-gem-w' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/retrieval-SfM-120k/rSfM120k-tl-resnet101-gem-w-a155e54.pth',
|
||||
'rSfM120k-tl-resnet152-gem-w' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/retrieval-SfM-120k/rSfM120k-tl-resnet152-gem-w-f39cada.pth',
|
||||
'gl18-tl-resnet50-gem-w' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/gl18/gl18-tl-resnet50-gem-w-83fdc30.pth',
|
||||
'gl18-tl-resnet101-gem-w' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/gl18/gl18-tl-resnet101-gem-w-a4d43db.pth',
|
||||
'gl18-tl-resnet152-gem-w' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/gl18/gl18-tl-resnet152-gem-w-21278d5.pth',
|
||||
}
|
||||
|
||||
datasets_names = ['oxford5k', 'paris6k', 'roxford5k', 'rparis6k']
|
||||
|
||||
parser = argparse.ArgumentParser(description='PyTorch CNN Image Retrieval Testing End-to-End')
|
||||
|
||||
# test options
|
||||
parser.add_argument('--network', '-n', metavar='NETWORK',
|
||||
help="network to be evaluated: " +
|
||||
" | ".join(PRETRAINED.keys()))
|
||||
parser.add_argument('--datasets', '-d', metavar='DATASETS', default='roxford5k,rparis6k',
|
||||
help="comma separated list of test datasets: " +
|
||||
" | ".join(datasets_names) +
|
||||
" (default: 'roxford5k,rparis6k')")
|
||||
parser.add_argument('--image-size', '-imsize', default=1024, type=int, metavar='N',
|
||||
help="maximum size of longer image side used for testing (default: 1024)")
|
||||
parser.add_argument('--multiscale', '-ms', metavar='MULTISCALE', default='[1]',
|
||||
help="use multiscale vectors for testing, " +
|
||||
" examples: '[1]' | '[1, 1/2**(1/2), 1/2]' | '[1, 2**(1/2), 1/2**(1/2)]' (default: '[1]')")
|
||||
|
||||
# GPU ID
|
||||
parser.add_argument('--gpu-id', '-g', default='0', metavar='N',
|
||||
help="gpu id used for testing (default: '0')")
|
||||
|
||||
def main():
|
||||
args = parser.parse_args()
|
||||
|
||||
# check if there are unknown datasets
|
||||
for dataset in args.datasets.split(','):
|
||||
if dataset not in datasets_names:
|
||||
raise ValueError('Unsupported or unknown dataset: {}!'.format(dataset))
|
||||
|
||||
# check if test dataset are downloaded
|
||||
# and download if they are not
|
||||
download_train(get_data_root())
|
||||
download_test(get_data_root())
|
||||
|
||||
# setting up the visible GPU
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id
|
||||
|
||||
# loading network
|
||||
# pretrained networks (downloaded automatically)
|
||||
print(">> Loading network:\n>>>> '{}'".format(args.network))
|
||||
state = load_url(PRETRAINED[args.network], model_dir=os.path.join(get_data_root(), 'networks'))
|
||||
# state = torch.load(args.network)
|
||||
# parsing net params from meta
|
||||
# architecture, pooling, mean, std required
|
||||
# the rest has default values, in case that is doesnt exist
|
||||
net_params = {}
|
||||
net_params['architecture'] = state['meta']['architecture']
|
||||
net_params['pooling'] = state['meta']['pooling']
|
||||
net_params['local_whitening'] = state['meta'].get('local_whitening', False)
|
||||
net_params['regional'] = state['meta'].get('regional', False)
|
||||
net_params['whitening'] = state['meta'].get('whitening', False)
|
||||
net_params['mean'] = state['meta']['mean']
|
||||
net_params['std'] = state['meta']['std']
|
||||
net_params['pretrained'] = False
|
||||
# network initialization
|
||||
net = init_network(net_params)
|
||||
net.load_state_dict(state['state_dict'])
|
||||
|
||||
print(">>>> loaded network: ")
|
||||
print(net.meta_repr())
|
||||
|
||||
# setting up the multi-scale parameters
|
||||
ms = list(eval(args.multiscale))
|
||||
print(">>>> Evaluating scales: {}".format(ms))
|
||||
|
||||
# moving network to gpu and eval mode
|
||||
net.cuda()
|
||||
net.eval()
|
||||
|
||||
# set up the transform
|
||||
normalize = transforms.Normalize(
|
||||
mean=net.meta['mean'],
|
||||
std=net.meta['std']
|
||||
)
|
||||
transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
normalize
|
||||
])
|
||||
|
||||
# evaluate on test datasets
|
||||
datasets = args.datasets.split(',')
|
||||
for dataset in datasets:
|
||||
start = time.time()
|
||||
|
||||
print('>> {}: Extracting...'.format(dataset))
|
||||
|
||||
# prepare config structure for the test dataset
|
||||
cfg = configdataset(dataset, os.path.join(get_data_root(), 'test'))
|
||||
images = [cfg['im_fname'](cfg,i) for i in range(cfg['n'])]
|
||||
qimages = [cfg['qim_fname'](cfg,i) for i in range(cfg['nq'])]
|
||||
try:
|
||||
bbxs = [tuple(cfg['gnd'][i]['bbx']) for i in range(cfg['nq'])]
|
||||
except:
|
||||
bbxs = None # for holidaysmanrot and copydays
|
||||
|
||||
# extract database and query vectors
|
||||
print('>> {}: database images...'.format(dataset))
|
||||
vecs = extract_vectors(net, images, args.image_size, transform, ms=ms)
|
||||
print('>> {}: query images...'.format(dataset))
|
||||
qvecs = extract_vectors(net, qimages, args.image_size, transform, bbxs=bbxs, ms=ms)
|
||||
|
||||
print('>> {}: Evaluating...'.format(dataset))
|
||||
|
||||
# convert to numpy
|
||||
vecs = vecs.numpy()
|
||||
qvecs = qvecs.numpy()
|
||||
|
||||
# search, rank, and print
|
||||
scores = np.dot(vecs.T, qvecs)
|
||||
ranks = np.argsort(-scores, axis=0)
|
||||
compute_map_and_print(dataset, ranks, cfg['gnd'])
|
||||
|
||||
print('>> {}: elapsed time: {}'.format(dataset, htime(time.time()-start)))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
Reference in New Issue
Block a user