first push
This commit is contained in:
246
interface.py
Normal file
246
interface.py
Normal file
@ -0,0 +1,246 @@
|
||||
# coding=utf-8
|
||||
# /usr/bin/env pythpn
|
||||
|
||||
import torch
|
||||
from torch.utils.model_zoo import load_url
|
||||
from torchvision import transforms
|
||||
from cirtorch.datasets.testdataset import configdataset
|
||||
from cirtorch.utils.download import download_train, download_test
|
||||
from cirtorch.utils.evaluate import compute_map_and_print
|
||||
from cirtorch.utils.general import get_data_root, htime
|
||||
from cirtorch.networks.imageretrievalnet_cpu import init_network, extract_vectors
|
||||
from cirtorch.datasets.datahelpers import imresize
|
||||
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from flask import Flask, request
|
||||
import json, io, sys, time, traceback, argparse, logging, subprocess, pickle, os, yaml,shutil
|
||||
import cv2
|
||||
import pdb
|
||||
from werkzeug.utils import cached_property
|
||||
from apscheduler.schedulers.background import BackgroundScheduler
|
||||
from multiprocessing import Pool
|
||||
|
||||
app = Flask(__name__)
|
||||
|
||||
@app.route("/")
|
||||
def index():
|
||||
return ""
|
||||
|
||||
@app.route("/images/*", methods=['GET','POST'])
|
||||
def accInsurance():
|
||||
"""
|
||||
flask request process handle
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
if request.method == 'GET':
|
||||
return json.dumps({'err': 1, 'msg': 'POST only'})
|
||||
else:
|
||||
app.logger.debug("print headers------")
|
||||
headers = request.headers
|
||||
headers_info = ""
|
||||
for k, v in headers.items():
|
||||
headers_info += "{}: {}\n".format(k, v)
|
||||
app.logger.debug(headers_info)
|
||||
|
||||
app.logger.debug("print forms------")
|
||||
forms_info = ""
|
||||
for k, v in request.form.items():
|
||||
forms_info += "{}: {}\n".format(k, v)
|
||||
app.logger.debug(forms_info)
|
||||
|
||||
if 'query' not in request.files:
|
||||
return json.dumps({'err': 2, 'msg': 'query image is empty'})
|
||||
|
||||
if 'sig' not in request.form:
|
||||
return json.dumps({'err': 3, 'msg': 'sig is empty'})
|
||||
|
||||
if 'q_no' not in request.form:
|
||||
return json.dumps({'err': 4, 'msg': 'no is empty'})
|
||||
|
||||
if 'q_did' not in request.form:
|
||||
return json.dumps({'err': 5, 'msg': 'did is empty'})
|
||||
|
||||
if 'q_id' not in request.form:
|
||||
return json.dumps({'err': 6, 'msg': 'id is empty'})
|
||||
|
||||
if 'type' not in request.form:
|
||||
return json.dumps({'err': 7, 'msg': 'type is empty'})
|
||||
|
||||
img_name = request.files['query'].filename
|
||||
img_bytes = request.files['query'].read()
|
||||
img = request.files['query']
|
||||
sig = request.form['sig']
|
||||
q_no = request.form['q_no']
|
||||
q_did = request.form['q_did']
|
||||
q_id = request.form['q_id']
|
||||
type = request.form['type']
|
||||
|
||||
if str(type) not in types:
|
||||
return json.dumps({'err': 8, 'msg': 'type is not exist'})
|
||||
|
||||
if img_bytes is None:
|
||||
return json.dumps({'err': 10, 'msg': 'img is none'})
|
||||
|
||||
results = imageRetrieval().retrieval_online_v0(img, q_no, q_did, q_id, type)
|
||||
|
||||
data = dict()
|
||||
data['query'] = img_name
|
||||
data['sig'] = sig
|
||||
data['type'] = type
|
||||
data['q_no'] = q_no
|
||||
data['q_did'] = q_did
|
||||
data['q_id'] = q_id
|
||||
data['results'] = results
|
||||
|
||||
return json.dumps({'err': 0, 'msg': 'success', 'data': data})
|
||||
|
||||
except:
|
||||
app.logger.exception(sys.exc_info())
|
||||
return json.dumps({'err': 9, 'msg': 'unknow error'})
|
||||
|
||||
|
||||
class imageRetrieval():
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def cosine_dist(self, x, y):
|
||||
return 100 * float(np.dot(x, y))/(np.dot(x,x)*np.dot(y,y)) ** 0.5
|
||||
|
||||
def inference(self, img):
|
||||
try:
|
||||
input = Image.open(img).convert("RGB")
|
||||
input = imresize(input, 224)
|
||||
input = transforms(input).unsqueeze()
|
||||
with torch.no_grad():
|
||||
vect = net(input)
|
||||
return vect
|
||||
except:
|
||||
print('cannot indentify error')
|
||||
|
||||
def retrieval_online_v0(self, img, q_no, q_did, q_id, type):
|
||||
# load model
|
||||
query_vect = self.inference(img)
|
||||
query_vect = list(query_vect.detach().numpy().T[0])
|
||||
|
||||
lsh = lsh_dict[str(type)]
|
||||
response = lsh.query(query_vect, num_results=1, distance_func = "cosine")
|
||||
|
||||
try:
|
||||
similar_path = response[0][0][1]
|
||||
score = np.rint(self.cosine_dist(list(query_vect), list(response[0][0][0])))
|
||||
rank_list = similar_path.split("/")
|
||||
s_id, s_did, s_no = rank_list[-1].split("_")[-1].split(".")[0], rank_list[-1].split("_")[0], rank_list[-2]
|
||||
results = [{"s_no": s_no, "r_did": s_did, "s_id": s_id, "score": score}]
|
||||
except:
|
||||
results = []
|
||||
|
||||
img_path = "/{}/{}_{}".format(q_no, q_did, q_id)
|
||||
lsh.index(query_vect, extra_data=img_path)
|
||||
lsh_dict[str(type)] = lsh
|
||||
|
||||
return results
|
||||
|
||||
|
||||
|
||||
class initModel():
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def init_model(self, network, model_dir, types):
|
||||
print(">> Loading network:\n>>>> '{}'".format(network))
|
||||
# state = load_url(PRETRAINED[args.network], model_dir=os.path.join(get_data_root(), 'networks'))
|
||||
state = torch.load(network)
|
||||
# parsing net params from meta
|
||||
# architecture, pooling, mean, std required
|
||||
# the rest has default values, in case that is doesnt exist
|
||||
net_params = {}
|
||||
net_params['architecture'] = state['meta']['architecture']
|
||||
net_params['pooling'] = state['meta']['pooling']
|
||||
net_params['local_whitening'] = state['meta'].get('local_whitening', False)
|
||||
net_params['regional'] = state['meta'].get('regional', False)
|
||||
net_params['whitening'] = state['meta'].get('whitening', False)
|
||||
net_params['mean'] = state['meta']['mean']
|
||||
net_params['std'] = state['meta']['std']
|
||||
net_params['pretrained'] = False
|
||||
# network initialization
|
||||
net = init_network(net_params)
|
||||
net.load_state_dict(state['state_dict'])
|
||||
print(">>>> loaded network: ")
|
||||
print(net.meta_repr())
|
||||
# moving network to gpu and eval mode
|
||||
# net.cuda()
|
||||
net.eval()
|
||||
|
||||
# set up the transform
|
||||
normalize = transforms.Normalize(
|
||||
mean=net.meta['mean'],
|
||||
std=net.meta['std']
|
||||
)
|
||||
transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
normalize
|
||||
])
|
||||
|
||||
lsh_dict = dict()
|
||||
for type in types:
|
||||
with open(os.path.join(model_dir, "dataset_index_{}.pkl".format(str(type))), "rb") as f:
|
||||
lsh = pickle.load(f)
|
||||
|
||||
lsh_dict[str(type)] = lsh
|
||||
|
||||
return net, lsh_dict, transforms
|
||||
|
||||
def init(self):
|
||||
with open('config.yaml', 'r') as f:
|
||||
conf = yaml.load(f)
|
||||
|
||||
app.logger.info(conf)
|
||||
host = conf['website']['host']
|
||||
port = conf['website']['port']
|
||||
network = conf['model']['network']
|
||||
model_dir = conf['model']['model_dir']
|
||||
types = conf['model']['type']
|
||||
|
||||
net, lsh_dict, transforms = self.init_model(network, model_dir, types)
|
||||
|
||||
return host, port, net, lsh_dict, transforms, model_dir, types
|
||||
|
||||
|
||||
def job():
|
||||
for type in types:
|
||||
with open(os.path.join(model_dir, "dataset_index_{}_v0.pkl".format(str(type))), "wb") as f:
|
||||
pickle.dump(lsh_dict[str(type)], f)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""
|
||||
start app from ssh
|
||||
"""
|
||||
scheduler = BackgroundScheduler()
|
||||
host, port, net, lsh_dict, transforms, model_dir, types = initModel().init()
|
||||
app.run(host=host, port=port, debug=True)
|
||||
print("start server {}:{}".format(host, port))
|
||||
|
||||
scheduler.add_job(job, 'interval', seconds= 30)
|
||||
scheduler.start()
|
||||
|
||||
else:
|
||||
"""
|
||||
start app from gunicorn
|
||||
"""
|
||||
scheduler = BackgroundScheduler()
|
||||
gunicorn_logger = logging.getLogger("gunicorn.error")
|
||||
app.logger.handlers = gunicorn_logger.handlers
|
||||
app.logger.setLevel(gunicorn_logger.level)
|
||||
|
||||
host, port, net, lsh_dict, transforms, model_dir, types = initModel().init()
|
||||
app.logger.info("started from gunicorn...")
|
||||
|
||||
scheduler.add_job(job, 'interval', seconds=30)
|
||||
scheduler.start()
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user