Files
ieemoo-ai-searchv2/sample_server.py
2022-11-22 15:32:06 +08:00

113 lines
4.0 KiB
Python
Executable File

import sys
import argparse
#from utils.retrieval_index import EvaluteMap
from utils.tools import EvaluteMap
from utils.retrieval_feature import AntiFraudFeatureDataset
from utils.monitor import Moniting
from utils.updateObs import *
#from utils.decide import Decide
from utils.config import cfg
from utils.tools import createNet
from flask import request,Flask, jsonify
from utils.forsegmentation import analysis
from gevent.pywsgi import WSGIServer
import os, base64, stat, shutil
import pdb
sys.path.append('RAFT')
sys.path.append('RAFT/core')
sys.path.append('RAFT/core/utils')
from RAFT.analysis_video import *
import time
# Expose only GPUs 0 and 1 to this process.
# NOTE(review): CUDA reads this variable when it is first initialised; if one of
# the imports above has already initialised CUDA, this assignment has no effect — confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = '0,1'
# NOTE(review): nothing visible in this file registers routes on or serves this app.
app = Flask(__name__)
# Options consumed by raft_init_model() below. parse_known_args() is used so that
# unrelated command-line flags are ignored instead of causing an argparse error.
parser = argparse.ArgumentParser()
parser.add_argument('--model', default='RAFT/models/raft-things.pth',help="restore checkpoint")
parser.add_argument('--small', action='store_true', help='use small model')
parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
parser.add_argument('--alternate_corr', action='store_true', help='use efficent correlation implementation')
opt, unknown = parser.parse_known_args()
# Status codes embedded in the strings returned by search()
# (English translation of the string literal below):
#   00: video could not be parsed (frame extraction failed / no frames produced)
#   01: barcode is not on the monitoring list
#   02: no product detected (search() uses it when the score is below cfg.THRESHOLD)
#   03: abnormal output (an exception occurred)
#   04: correctly recognised
'''
status 状态码
00: 视频未解析成功(视频截取错误)
01: 未纳入监查列表
02: 未检测出商品
03: 异常输出
04: 正确识别
'''
status = ['00', '01', '02', '03', '04']
# Feature network, its input transform and `ms` (meaning not visible here —
# TODO confirm); all three are passed to AntiFraudFeatureDataset.extractFeature() in search().
net, transform, ms = createNet()
# Model passed to analysis_video() in search(); presumably used to extract sample
# images of a video into cfg.SAMPLEIMGS — confirm against RAFT/analysis_video.
raft_model = raft_init_model(opt)
def get_video(start_time="2022-01-25", end_time="2022-01-26"):
    """Download the collection videos recorded in a date window into ``cfg.SAVIDEOPATH``.

    Queries the ieemoo collection API for videos between *start_time* and
    *end_time* (``YYYY-MM-DD`` strings; the defaults are the window that used
    to be hard-coded) and saves every returned ``videoPath`` under its base name.

    Raises:
        requests.HTTPError: if the API answers with an HTTP error status.
        KeyError: if the JSON reply lacks ``data`` or an item lacks ``videoPath``.

    NOTE(review): ``requests`` and ``_progress`` are not imported explicitly in
    this file; presumably they come from one of the ``import *`` lines — confirm.
    """
    import urllib.request

    url = "https://api.ieemoo.com/emoo-train/collection/getVideoCollectByTime.do"
    data = {"startTime": start_time, "endTime": end_time}
    r = requests.post(url=url, data=data)
    # Fail with a clear HTTP error instead of an opaque JSON/KeyError later on.
    r.raise_for_status()
    # Collect every URL first (as before) so a malformed reply aborts before any download.
    video_urls = [item["videoPath"] for item in r.json()['data']]
    save_dir = cfg.SAVIDEOPATH
    for video_url in video_urls:
        save_path = os.path.join(save_dir, os.path.basename(video_url))
        urllib.request.urlretrieve(video_url, save_path, _progress)
def _format_search_result(uuid_barcode, score, status_code, video_name):
    """Return ``<uuid_barcode>_<score>_!_<status_code>_<video_name>``.

    Single place that defines the reply layout of ``search``; ``match`` reads
    the score back as the last ``_``-separated field before the ``!``.
    """
    return '_'.join([uuid_barcode, str(score), '!', status_code, video_name])


def search(video_name):
    """Score one saved video against the retrieval gallery for its barcode.

    *video_name* is a file name inside ``cfg.SAVIDEOPATH``. The part before the
    first ``.`` is the ``uuid_barcode``; its last ``_``-separated field is the
    barcode. ``analysis_video`` is run on the video with ``cfg.SAMPLEIMGS`` as
    the image directory, and the resulting sample images are matched with
    ``EvaluteMap().match_images``.

    Returns:
        ``<uuid_barcode>_<score>_!_<status>_<video_name>`` where status is one of
        the module-level ``status`` codes:

        - '00': ``analysis_video`` produced no images;
        - '01': barcode not monitored;
        - '02': score below ``cfg.THRESHOLD``;
        - '03': an exception occurred;
        - '04': score at or above the threshold.

        Codes '00', '01' and '03' carry the placeholder score '0.90'.
    """
    import traceback

    start = time.time()
    # Derived before the try block so the error path can always build its reply
    # (previously an early failure turned into UnboundLocalError in the handler).
    uuid_barcode = video_name.split('.')[0]
    barcode_name = uuid_barcode.split('_')[-1]
    try:
        video_path = os.sep.join([cfg.SAVIDEOPATH, video_name])
        photo_nu = analysis_video(raft_model, video_path, cfg.SAMPLEIMGS, uuid_barcode)
        if Moniting(barcode_name).search() == 'nomatch':
            pre_status = status[1]
            res = '0.90'
        else:
            if photo_nu == 0:
                deleteimg(uuid_barcode)
                # Same layout as every other reply (was previously missing the '_' after '!').
                return _format_search_result(uuid_barcode, '0.90', status[0], video_name)
            feature_dict = AntiFraudFeatureDataset(
                uuid_barcode, cfg.SAMPLEIMGS, 'sample'
            ).extractFeature(net, transform, ms)
            res = EvaluteMap().match_images(feature_dict, barcode_name)
            pre_status = status[2] if res < cfg.THRESHOLD else status[4]
    except Exception:
        # Best-effort per video: report '03' but keep the traceback visible.
        # (Bare `except:` also swallowed KeyboardInterrupt/SystemExit.)
        traceback.print_exc()
        return _format_search_result(uuid_barcode, '0.90', status[3], video_name)
    data = '_'.join([uuid_barcode, str(res), '!'])
    print(data)
    # Sample images are removed only for correct recognitions ('04').
    # NOTE(review): the original Chinese comment also mentioned removing abnormal
    # cases, which this code does not do — confirm intent.
    if pre_status == status[4]:
        deleteimg(uuid_barcode)
    result = _format_search_result(uuid_barcode, res, pre_status, video_name)
    elapsed = time.time() - start
    print('程序运行总时间:%s' % (elapsed))  # "total running time"
    print(result)
    return result
def match(output_path='tmp.txt'):
    """Run ``search`` on every video in ``cfg.SAVIDEOPATH`` and report the hit rate.

    A result is a hit when its score is above ``cfg.THRESHOLD``; hits are appended
    one per line to *output_path* (default ``'tmp.txt'``, as before). Results
    carrying the placeholder score 0.90 (``search`` codes '00', '01', '03') are
    skipped and removed from the total, independent of the threshold. When at
    least one hit was found, ``hits / total`` is printed.

    NOTE(review): a genuine score of exactly 0.9 is indistinguishable from the
    placeholder and is skipped too.
    """
    # List once so the total and the iterated set cannot disagree.
    video_names = os.listdir(cfg.SAVIDEOPATH)
    total = len(video_names)
    hits = 0
    with open(output_path, 'a', encoding='utf-8') as out:
        for video_name in video_names:
            result = search(video_name)
            # Score is the last '_'-separated field before the '!' marker.
            score = float(result.split('!')[0].split('_')[-2])
            if score == 0.90:
                total -= 1
            elif score > cfg.THRESHOLD:
                out.write(result + '\n')
                hits += 1
    if hits:
        print(hits / total)
def deleteimg(uuid_barcode):
    """Remove every file in ``cfg.SAMPLEIMGS`` whose name contains *uuid_barcode*.

    NOTE(review): this is a substring test, so a file whose name merely contains
    *uuid_barcode* somewhere (not only as a prefix) is deleted as well.
    """
    sample_dir = cfg.SAMPLEIMGS
    doomed = [name for name in os.listdir(sample_dir) if uuid_barcode in name]
    for name in doomed:
        os.remove(os.sep.join([sample_dir, name]))
if __name__ == '__main__':
    # Batch mode: score every video in cfg.SAVIDEOPATH via match().
    match()