106 lines
4.2 KiB
Python
Executable File
106 lines
4.2 KiB
Python
Executable File
import sys
|
|
import argparse
|
|
#from utils.retrieval_index import EvaluteMap
|
|
from utils.tools import EvaluteMap
|
|
from utils.retrieval_feature import AntiFraudFeatureDataset
|
|
from utils.monitor import Moniting
|
|
from utils.updateObs import *
|
|
from utils.config import cfg
|
|
from utils.tools import createNet
|
|
from flask import request,Flask
|
|
from utils.forsegmentation import analysis
|
|
from gevent.pywsgi import WSGIServer
|
|
import os, base64, stat, shutil, json, time
|
|
sys.path.append('RAFT')
|
|
sys.path.append('RAFT/core')
|
|
sys.path.append('RAFT/core/utils')
|
|
from RAFT.analysis_video import *
|
|
import logging.config
|
|
from skywalking import agent, config
|
|
from threading import Thread
|
|
|
|
# Optional SkyWalking tracing: the agent is started only when both the
# collector address and the service name are present in the environment.
SW_SERVER = os.environ.get('SW_AGENT_COLLECTOR_BACKEND_SERVICES')
SW_SERVICE_NAME = os.environ.get('SW_AGENT_NAME')

if all((SW_SERVER, SW_SERVICE_NAME)):
    # init() is called without arguments; presumably the agent takes the
    # collector address and service name from the SW_AGENT_* variables
    # checked above -- confirm against the installed skywalking version.
    config.init()
    agent.start()

# Expose only the first GPU to this process (presumably this must happen
# before any CUDA context is created).
os.environ["CUDA_VISIBLE_DEVICES"] = '0'

# Flask application on which the HTTP routes are registered.
app = Flask(__name__)
def _str2bool(value):
    """Interpret a command-line value as a boolean.

    ``argparse`` applies ``type`` to the raw string, and ``bool('False')`` is
    ``True``, so ``type=bool`` can never produce ``False`` from a normal
    argument. This converter understands the usual spellings instead.

    Args:
        value: The string given on the command line (a ``bool`` is passed
            through unchanged).

    Returns:
        ``True`` for 1/true/t/yes/y/on, ``False`` for 0/false/f/no/n/off and
        the empty string (case-insensitive, surrounding whitespace ignored).

    Raises:
        argparse.ArgumentTypeError: For any other value, so argparse reports
            a usage error instead of silently guessing.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if text in ('', '0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


# Command-line options consumed by the RAFT model loader (raft_init_model(opt)).
parser = argparse.ArgumentParser()
parser.add_argument('--model', default='../module/ieemoo-ai-searchv2/model/now/raft-small.pth', help="restore checkpoint")
# Small RAFT variant by default; disable with e.g. ``--small False``.
parser.add_argument('--small', type=_str2bool, default=True, help='use small model')
parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
parser.add_argument('--alternate_corr', action='store_true', help='use efficent correlation implementation')
# parse_known_args: extra arguments from whatever launches this module are
# collected in ``unknown`` instead of aborting start-up with a usage error.
opt, unknown = parser.parse_known_args()
# Status codes; the search handler answers with "<uuid_barcode>_<code>".
#   00: video could not be parsed (error while clipping the video)
#   01: barcode is not on the monitoring list
#   02: no product detected
#   03: abnormal output (also used when request handling raises)
#   04: recognised correctly
status = ['%02d' % code for code in range(5)]

# Build the retrieval network and the RAFT model once at import time so that
# every request reuses the same instances.
net, transform, ms = createNet()
raft_model = raft_init_model(opt)
def setup_logging(path):
    """Configure logging from a JSON ``dictConfig`` file and return the root logger.

    Args:
        path: Path to a JSON file in :func:`logging.config.dictConfig` format.
            When the file does not exist, the logging configuration is left
            untouched.

    Returns:
        ``logging.getLogger("root")``. It is returned even when *path* is
        missing, so callers always receive a usable logger.
    """
    if os.path.exists(path):
        # JSON is UTF-8 by specification; do not depend on the locale encoding.
        # Named log_config so it does not shadow the module-level `config` import.
        with open(path, 'r', encoding='utf-8') as f:
            log_config = json.load(f)
        logging.config.dictConfig(log_config)
    return logging.getLogger("root")


# Module-wide logger used by the request handlers.
logger = setup_logging('utils/logging.json')
def _start_obs_upload(file_path, state):
    """Hand a saved video and its status code to ``AddObs`` on a new thread.

    ``AddObs`` is not defined in this module (presumably it comes from the
    ``utils.updateObs`` star import). The thread is started and not joined,
    so the HTTP response does not wait for it.
    """
    worker = Thread(target=AddObs, kwargs={'file_path': file_path, 'status': state})
    worker.start()


@app.route('/searchv2', methods=['POST', 'GET'])
def search():
    """Analyse an uploaded video and answer with ``"<uuid_barcode>_<code>"``.

    Request:
        form ``video_name`` (required): name of the uploaded video. The text
            before the first '.' is used as ``uuid_barcode``; the text after
            its last '_' is the barcode checked with ``Moniting``.
        form ``video_extra_info`` (optional): text written to
            ``cfg.Ocrtxt/<uuid_barcode>.txt``.
        file ``video``: the video content, saved to ``cfg.VIDEOPATH/<video_name>``.

    Returns:
        ``'Need video_name'`` when the name is missing. Otherwise
        ``uuid_barcode + '_' + code``: ``status[1]`` ('01') for barcodes that
        are not monitored, the value returned by ``analysis_video`` for
        monitored ones, and ``status[3]`` ('03') when any step raises.
    """
    # Pre-bound so the error path below can always build a response
    # (previously an early failure raised NameError here instead).
    uuid_barcode = ''
    saved_video_path = None
    try:
        video_name = request.form.get('video_name')
        if video_name is None:
            return 'Need video_name'
        # The name is client-supplied: keep only its last path component so it
        # cannot point outside cfg.VIDEOPATH / cfg.Ocrtxt.
        video_name = os.path.basename(video_name)
        logger.info('get video ' + video_name)
        uuid_barcode = video_name.split('.')[0]
        barcode_name = uuid_barcode.split('_')[-1]

        video_extra_info = request.form.get('video_extra_info')
        if video_extra_info is not None:
            ocr_file_path = os.sep.join([cfg.Ocrtxt, uuid_barcode + '.txt'])
            with open(ocr_file_path, 'w') as f:
                f.write(video_extra_info)

        video_data = request.files['video']
        video_path = os.sep.join([cfg.VIDEOPATH, video_name])
        video_data.save(video_path)
        saved_video_path = video_path

        if Moniting(barcode_name).search() == 'nomatch':
            # Barcode not monitored: the analysis still runs (match=False),
            # but the answer is fixed to '01'.
            state = status[1]
            analysis_video(raft_model, video_path, '', uuid_barcode, None, net=net, transform=transform, ms=ms, match=False)
        else:
            state = analysis_video(raft_model, video_path, '', uuid_barcode, None, net=net, transform=transform, ms=ms, match=True)

        result = uuid_barcode + '_' + state
        try:
            _start_obs_upload(video_path, state)
        except Exception:
            # A failed upload must not change the answer sent to the client.
            logger.exception('failed to start OBS upload for %s', video_path)
            print('Exception >>>>> {}'.format(result))
            return result
        logger.info(result)
        print('result >>>>> {}'.format(result))
        return result
    except Exception:
        # Any processing error is reported as status '03' (abnormal output).
        logger.exception('searchv2 failed for %r', uuid_barcode)
        error_result = uuid_barcode + '_' + status[3]
        # Only upload when the video actually reached disk.
        if saved_video_path is not None:
            try:
                _start_obs_upload(saved_video_path, status[3])
            except Exception:
                logger.exception('failed to start OBS upload for %s', saved_video_path)
        return error_result
if __name__ == '__main__':
    # Run Flask's built-in server, listening on every interface at port 8085.
    # NOTE(review): gevent's WSGIServer is imported at the top of this file but
    # never used; confirm whether production should be served through it.
    app.run(port=8085, host='0.0.0.0')