# Files
# ieemoo-ai-searchv2/server.py
# 2022-11-22 15:32:06 +08:00
#
# 86 lines
# 3.1 KiB
# Python
# Executable File

import sys
import argparse
#from utils.retrieval_index import EvaluteMap
from utils.tools import EvaluteMap
from utils.retrieval_feature import AntiFraudFeatureDataset
from utils.monitor import Moniting
from utils.updateObs import *
from utils.config import cfg
from utils.tools import createNet
from flask import request,Flask, jsonify
from utils.forsegmentation import analysis
from gevent.pywsgi import WSGIServer
import os, base64, stat, shutil
sys.path.append('RAFT')
sys.path.append('RAFT/core')
sys.path.append('RAFT/core/utils')
from RAFT.analysis_video import *
# Expose only GPUs 0 and 1 to this process.
# NOTE(review): CUDA_VISIBLE_DEVICES only takes effect if set before a CUDA
# context is created (createNet / raft_init_model below). Confirm that none of
# the imports above initialise CUDA at import time.
os.environ["CUDA_VISIBLE_DEVICES"] = '0,1'
app = Flask(__name__)

# Options for the RAFT model, passed to raft_init_model(opt) below.
# parse_known_args() returns unrecognised arguments in `unknown` instead of
# exiting, so extra command-line flags do not abort start-up.
parser = argparse.ArgumentParser()
parser.add_argument('--model', default='RAFT/models/raft-things.pth', help="restore checkpoint")
parser.add_argument('--small', action='store_true', help='use small model')
parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
parser.add_argument('--alternate_corr', action='store_true', help='use efficient correlation implementation')
opt, unknown = parser.parse_known_args()

# Status codes appended to /search responses (English version of the table
# in the string below):
#   00: video could not be parsed (frame extraction failed)
#   01: not on the monitoring list
#   02: no product detected
#   03: abnormal output
#   04: recognised correctly
'''
status 状态码
00: 视频未解析成功(视频截取错误)
01: 未纳入监查列表
02: 未检测出商品
03: 异常输出
04: 正确识别
'''
status = ['00', '01', '02', '03', '04']

# Models are created once at import time and reused by every request.
net, transform, ms = createNet()
raft_model = raft_init_model(opt)
@app.route('/search', methods=['POST'])
def search():
    """Process an uploaded video and return an encoded result string.

    Expects a POST with form field ``video_name`` and file field ``video``.
    The video is saved under ``cfg.VIDEOPATH`` and ``analysis_video`` writes
    images for it into ``cfg.TEST_IMG_DIR``. If ``Moniting`` reports a match
    for the barcode, image features are matched against that barcode and the
    score is compared with ``cfg.THRESHOLD``.

    The name is expected to look like ``<uuid>_<barcode>.<ext>``: everything
    before the first '.' is ``uuid_barcode`` and its last '_'-separated part
    is the barcode.

    Returns:
        ``<uuid_barcode>_<score>_!_<status>_<video_name>``, where status is one
        of the module-level ``status`` codes:
        '00' no images extracted, '01' not monitored, '02' score below
        threshold, '03' an exception occurred, '04' recognised.
    """
    # Pre-bind everything the error handler uses so a failure early in the
    # request still produces a '03' response instead of a NameError/500.
    video_name = ''
    video_path = None
    uuid_barcode = ''
    try:
        # The name comes from the client: keep only its final path component
        # so the upload cannot be written outside cfg.VIDEOPATH.
        video_name = os.path.basename(request.form.get('video_name') or '')
        uuid_barcode = video_name.split('.')[0]
        if not uuid_barcode:
            # An empty id would make deleteimg('') match every stored image.
            raise ValueError('missing or malformed video_name: %r' % video_name)
        barcode_name = uuid_barcode.split('_')[-1]

        video_data = request.files['video']
        video_path = os.sep.join([cfg.VIDEOPATH, video_name])
        video_data.save(video_path)

        photo_nu = analysis_video(raft_model, video_path, cfg.TEST_IMG_DIR, uuid_barcode)
        Addimg(uuid_barcode)

        if Moniting(barcode_name).search() != 'nomatch':
            if photo_nu == 0:
                # No images were produced for this video.
                deleteimg(uuid_barcode)
                AddObs(video_path, status[0])
                return uuid_barcode + '_0.90_!_' + status[0] + '_' + video_name
            feature_dict = AntiFraudFeatureDataset(uuid_barcode).extractFeature(net, transform, ms)
            res = EvaluteMap().match_images(feature_dict, barcode_name)
            if res == 'nan':
                # NOTE(review): res becomes the *string* '0.90' and is then
                # compared with cfg.THRESHOLD below. If THRESHOLD is numeric,
                # that comparison raises TypeError and the request ends with
                # status '03'. The status[1] assignment here is always
                # overwritten below. Confirm the intended outcome for 'nan'.
                res = '0.90'
                pre_status = status[1]
            if res < cfg.THRESHOLD:
                pre_status = status[2]
            else:
                pre_status = status[4]
        else:
            # Barcode is not on the monitoring list.
            pre_status = status[1]
            res = '0.90'
    except Exception:
        app.logger.exception('search failed for video %r', video_name)
        if video_path is not None:
            AddObs(video_path, status[3])
        if uuid_barcode:
            deleteimg(uuid_barcode)
        return uuid_barcode + '_0.90_!_' + status[3] + '_' + video_name

    data = uuid_barcode + '_' + str(res) + '_!'
    # Images are only removed after a successful recognition.
    if pre_status == status[4]:
        deleteimg(uuid_barcode)
    AddObs(video_path, pre_status)
    print('result:', data)
    return data + '_' + pre_status + '_' + video_name
def deleteimg(uuid_barcode, img_dir=None):
    """Delete every file in the image directory whose name contains *uuid_barcode*.

    Args:
        uuid_barcode: Substring identifying one request's images. An empty or
            ``None`` value deletes nothing. An empty substring is contained in
            every name, so it would otherwise wipe the whole directory.
        img_dir: Directory to clean. Defaults to ``cfg.TEST_IMG_DIR``.
    """
    if not uuid_barcode:
        return
    if img_dir is None:
        img_dir = cfg.TEST_IMG_DIR
    for img_name in os.listdir(img_dir):
        # Plain substring match, as before: any name containing the id matches.
        if uuid_barcode in img_name:
            os.remove(os.path.join(img_dir, img_name))
if __name__ == "__main__":
    # Run Flask's built-in server with its default host and port.
    # Alternative kept disabled: serve with gevent instead via
    #   WSGIServer(('192.168.1.142', 6001), app).serve_forever()
    app.run()