Files
ieemoo-ai-searchv2/interface.py
2022-11-22 15:32:06 +08:00

247 lines
6.9 KiB
Python

# coding=utf-8
# /usr/bin/env python
import torch
from torch.utils.model_zoo import load_url
from torchvision import transforms
from cirtorch.datasets.testdataset import configdataset
from cirtorch.utils.download import download_train, download_test
from cirtorch.utils.evaluate import compute_map_and_print
from cirtorch.utils.general import get_data_root, htime
from cirtorch.networks.imageretrievalnet_cpu import init_network, extract_vectors
from cirtorch.datasets.datahelpers import imresize
from PIL import Image
import numpy as np
import pandas as pd
from flask import Flask, request
import json, io, sys, time, traceback, argparse, logging, subprocess, pickle, os, yaml,shutil
import cv2
import pdb
from werkzeug.utils import cached_property
from apscheduler.schedulers.background import BackgroundScheduler
from multiprocessing import Pool
# Flask application shared by every route in this module.
app = Flask(import_name=__name__)


@app.route("/")
def index():
    """Root endpoint: always answers with an empty body."""
    body = ""
    return body
# NOTE(review): Flask does not treat ``*`` as a wildcard, so this rule only
# matches the literal path ``/images/*`` -- confirm that is what clients call.
@app.route("/images/*", methods=['GET', 'POST'])
def accInsurance():
    """
    Handle one image-retrieval request.

    Expects a multipart POST carrying a ``query`` image file and the form
    fields ``sig``, ``q_no``, ``q_did``, ``q_id`` and ``type``. The image is
    looked up in the LSH index for ``type`` and then added to that index
    (see ``imageRetrieval.retrieval_online_v0``).

    :return: JSON string ``{"err": code, "msg": text}``; on success ``err`` is 0
        and a ``data`` object echoes the request fields plus ``results``.
    """
    try:
        if request.method == 'GET':
            return json.dumps({'err': 1, 'msg': 'POST only'})

        app.logger.debug("print headers------")
        app.logger.debug("".join("{}: {}\n".format(k, v) for k, v in request.headers.items()))
        app.logger.debug("print forms------")
        app.logger.debug("".join("{}: {}\n".format(k, v) for k, v in request.form.items()))

        if 'query' not in request.files:
            return json.dumps({'err': 2, 'msg': 'query image is empty'})
        # Checked in this order so each missing field keeps its original error code.
        required_fields = (
            ('sig', 3, 'sig is empty'),
            ('q_no', 4, 'no is empty'),
            ('q_did', 5, 'did is empty'),
            ('q_id', 6, 'id is empty'),
            ('type', 7, 'type is empty'),
        )
        for field, err_code, err_msg in required_fields:
            if field not in request.form:
                return json.dumps({'err': err_code, 'msg': err_msg})

        query_file = request.files['query']
        img_name = query_file.filename
        img_bytes = query_file.read()
        sig = request.form['sig']
        q_no = request.form['q_no']
        q_did = request.form['q_did']
        q_id = request.form['q_id']
        img_type = request.form['type']

        # ``types`` comes from config.yaml and may hold non-strings (e.g. ints),
        # whereas form values are strings and the LSH indexes are keyed by str(type).
        if str(img_type) not in {str(t) for t in types}:
            return json.dumps({'err': 8, 'msg': 'type is not exist'})
        # read() returns b"" for an empty upload, never None.
        if not img_bytes:
            return json.dumps({'err': 10, 'msg': 'img is none'})

        # Decode from the bytes already read rather than the consumed upload stream.
        results = imageRetrieval().retrieval_online_v0(
            io.BytesIO(img_bytes), q_no, q_did, q_id, img_type)
        data = {
            'query': img_name,
            'sig': sig,
            'type': img_type,
            'q_no': q_no,
            'q_did': q_did,
            'q_id': q_id,
            'results': results,
        }
        return json.dumps({'err': 0, 'msg': 'success', 'data': data})
    except Exception:
        app.logger.exception(sys.exc_info())
        return json.dumps({'err': 9, 'msg': 'unknow error'})
class imageRetrieval():
    """Query and update the per-type LSH indexes with one image at a time.

    Relies on module-level globals bound at start-up: ``net`` (the retrieval
    network), ``transforms`` (the preprocessing pipeline returned by
    ``initModel.init``) and ``lsh_dict`` (``str(type)`` -> LSH index).
    """

    def __init__(self):
        pass

    def cosine_dist(self, x, y):
        """Return the cosine similarity of *x* and *y* scaled to [-100, 100].

        Despite the name this is a similarity (100 = same direction), not a
        distance. Returns 0.0 when either vector has zero norm instead of
        producing NaN, which ``json.dumps`` would emit as invalid JSON.
        """
        denom = (np.dot(x, x) * np.dot(y, y)) ** 0.5
        if denom == 0:
            return 0.0
        return 100 * float(np.dot(x, y)) / denom

    def inference(self, img):
        """Compute the descriptor of a single image.

        :param img: a path or binary file-like object accepted by ``PIL.Image.open``.
        :return: the network output for the one-image batch; the caller
            transposes it and takes row 0.
        :raises Exception: any decoding/forward-pass error is logged and
            re-raised so callers never continue with a ``None`` descriptor.
        """
        try:
            image = Image.open(img).convert("RGB")
            image = imresize(image, 224)
            # unsqueeze needs an explicit dim: add the leading batch dimension.
            batch = transforms(image).unsqueeze(0)
            with torch.no_grad():
                return net(batch)
        except Exception:
            app.logger.exception('cannot identify image')
            raise

    def retrieval_online_v0(self, img, q_no, q_did, q_id, type):
        """Find the closest indexed image to *img*, then add *img* to the index.

        :param img: image source passed to ``inference`` (path or file-like).
        :param q_no: query number; first component of the new index entry.
        :param q_did: query did; second component of the new index entry.
        :param q_id: query id; last component of the new index entry.
        :param type: index type; selects ``lsh_dict[str(type)]``.
        :return: ``[]`` when no usable match is found, otherwise a one-element
            list ``[{"s_no", "r_did", "s_id", "score"}]`` describing the best
            match, where ``score`` is the rounded ``cosine_dist`` (-100..100).
        """
        query_vect = self.inference(img)
        # Transpose the network output and take row 0 as the flat descriptor.
        query_vect = list(query_vect.detach().numpy().T[0])
        lsh = lsh_dict[str(type)]
        response = lsh.query(query_vect, num_results=1, distance_func="cosine")
        try:
            best_vect = response[0][0][0]
            similar_path = response[0][0][1]
            score = np.rint(self.cosine_dist(list(query_vect), list(best_vect)))
            # Entries indexed below look like "/<no>/<did>_<id>"; presumably
            # pre-built entries may also carry a file extension.
            parts = similar_path.split("/")
            file_part = parts[-1]
            s_id = file_part.split("_")[-1].split(".")[0]
            s_did = file_part.split("_")[0]
            s_no = parts[-2]
            # NOTE(review): key is "r_did" while its siblings use "s_"; kept
            # unchanged because clients may depend on it.
            results = [{"s_no": s_no, "r_did": s_did, "s_id": s_id, "score": score}]
        except Exception:
            # E.g. no neighbour returned, or an entry whose extra data is not a
            # parsable path: report no match rather than failing the request.
            results = []
        img_path = "/{}/{}_{}".format(q_no, q_did, q_id)
        # Index the query itself (the index object is updated in place) so later
        # requests can match against it.
        lsh.index(query_vect, extra_data=img_path)
        return results
class initModel():
    """Load the configuration, the retrieval network and the per-type LSH indexes."""

    def __init__(self):
        pass

    def init_model(self, network, model_dir, types):
        """Load the network checkpoint and the LSH index of every type.

        :param network: path of a checkpoint loadable by ``torch.load`` that holds
            ``'meta'`` and ``'state_dict'`` entries.
        :param model_dir: directory containing ``dataset_index_<type>.pkl`` files.
        :param types: iterable of index types; each is stored under ``str(type)``.
        :return: ``(net, lsh_dict, transform)`` where ``transform`` is the
            ``ToTensor`` + ``Normalize`` pipeline built from the network meta.
        """
        # Import under a private name: the start-up code rebinds the module-level
        # name ``transforms`` to the pipeline this method returns.
        from torchvision import transforms as tv_transforms

        print(">> Loading network:\n>>>> '{}'".format(network))
        # This service runs the CPU network and never calls .cuda(), so map any
        # GPU-saved tensors onto the CPU while loading.
        state = torch.load(network, map_location="cpu")
        meta = state['meta']
        # architecture, pooling, mean and std are required; the rest default
        # when absent from the checkpoint meta.
        net_params = {
            'architecture': meta['architecture'],
            'pooling': meta['pooling'],
            'local_whitening': meta.get('local_whitening', False),
            'regional': meta.get('regional', False),
            'whitening': meta.get('whitening', False),
            'mean': meta['mean'],
            'std': meta['std'],
            'pretrained': False,
        }
        net = init_network(net_params)
        net.load_state_dict(state['state_dict'])
        print(">>>> loaded network: ")
        print(net.meta_repr())
        # CPU inference only: just switch to eval mode.
        net.eval()

        transform = tv_transforms.Compose([
            tv_transforms.ToTensor(),
            tv_transforms.Normalize(mean=net.meta['mean'], std=net.meta['std']),
        ])

        lsh_dict = dict()
        for index_type in types:
            index_path = os.path.join(model_dir, "dataset_index_{}.pkl".format(str(index_type)))
            # NOTE: pickle.load can execute arbitrary code; these index files
            # must come from a trusted source.
            with open(index_path, "rb") as f:
                lsh_dict[str(index_type)] = pickle.load(f)
        # Return the composed pipeline (not the torchvision module) so callers
        # can apply it directly to an image.
        return net, lsh_dict, transform

    def init(self):
        """Read ``config.yaml`` and load the network and indexes it names.

        :return: ``(host, port, net, lsh_dict, transform, model_dir, types)``.
        """
        with open('config.yaml', 'r') as f:
            # safe_load: yaml.load() without an explicit Loader is unsafe and
            # raises TypeError on PyYAML >= 6.
            conf = yaml.safe_load(f)
        app.logger.info(conf)
        host = conf['website']['host']
        port = conf['website']['port']
        network = conf['model']['network']
        model_dir = conf['model']['model_dir']
        types = conf['model']['type']
        net, lsh_dict, transform = self.init_model(network, model_dir, types)
        return host, port, net, lsh_dict, transform, model_dir, types
def job():
    """Snapshot every in-memory LSH index to disk.

    Writes ``lsh_dict[str(type)]`` to ``<model_dir>/dataset_index_<type>_v0.pkl``
    for each configured type, using the module-level ``types``, ``model_dir``
    and ``lsh_dict`` bound at start-up.
    """
    # NOTE(review): snapshots go to ``*_v0.pkl`` while initModel.init_model loads
    # ``dataset_index_<type>.pkl``, so they are not reloaded automatically on
    # restart -- confirm this is intended.
    for index_type in types:
        key = str(index_type)
        path = os.path.join(model_dir, "dataset_index_{}_v0.pkl".format(key))
        tmp_path = path + ".tmp"
        # Dump to a temporary file in the same directory first, so a crash
        # mid-dump cannot leave a truncated snapshot; os.replace then swaps it in
        # as a single rename.
        with open(tmp_path, "wb") as f:
            pickle.dump(lsh_dict[key], f)
        os.replace(tmp_path, path)
if __name__ == "__main__":
    # Started directly (e.g. from an ssh session) with Flask's built-in server.
    scheduler = BackgroundScheduler()
    host, port, net, lsh_dict, transforms, model_dir, types = initModel().init()
    # Register and start the snapshot job BEFORE app.run(): app.run() blocks
    # until the server stops, so code placed after it never runs while serving.
    scheduler.add_job(job, 'interval', seconds=30)
    scheduler.start()
    print("start server {}:{}".format(host, port))
    try:
        # use_reloader=False: with debug=True Flask otherwise re-executes this
        # script in a child process, loading the model twice and leaving the
        # parent's scheduler overwriting snapshots with never-updated indexes.
        # NOTE(review): debug=True also enables the interactive Werkzeug
        # debugger; confirm this is never exposed beyond a trusted network.
        app.run(host=host, port=port, debug=True, use_reloader=False)
    finally:
        scheduler.shutdown()
else:
    # Imported by gunicorn: route Flask logs through gunicorn's error logger.
    scheduler = BackgroundScheduler()
    gunicorn_logger = logging.getLogger("gunicorn.error")
    app.logger.handlers = gunicorn_logger.handlers
    app.logger.setLevel(gunicorn_logger.level)
    host, port, net, lsh_dict, transforms, model_dir, types = initModel().init()
    app.logger.info("started from gunicorn...")
    # NOTE(review): this branch runs in every process that imports the module,
    # so with several gunicorn workers each keeps its own indexes and scheduler
    # and they all write the same snapshot files -- confirm a single worker.
    scheduler.add_job(job, 'interval', seconds=30)
    scheduler.start()