67 lines
2.3 KiB
Python
67 lines
2.3 KiB
Python
import sys
|
|
import torch
|
|
|
|
from tools.config import cfg as conf
|
|
sys.path.append('contrast')
|
|
# from config import config as conf
|
|
# from model import resnet18, MobileNetV3_Large
|
|
from test_logic import similarity_interface
|
|
from img_data import queueImgs_add
|
|
|
|
# import pymilvus
|
|
|
|
|
|
class datacollection:
    """Mutable bag of query inputs handed to ``similarity_interface``.

    Every field is a class-level default; assigning on an instance shadows
    the default for that instance only.
    """

    # Query flags: barcode_flag -> a barcode is present ("有barcode");
    # add_flag -> truthy for an add (加购) event, falsy for a removal (退购).
    barcode_flag = add_flag = None
    # Raw queue payload, search/model handles (filled by getSimilarity from
    # model.milvusModel / model.similarityModel), and auxiliary lists.
    queImgsDict = mainMilvus = tempLibList = model = barcode_list = None
    # Run mode: True = production (run) mode, False = test mode.
    actionModel = True
|
|
|
|
|
|
class similarityResult:
    """Output holder filled in place by ``similarity.getSimilarity``.

    Which fields get populated depends on the query flags: ``top10`` is set
    on every branch, ``top1`` only when a barcode is present, ``tempLibList``
    only for add (加购) queries, and ``topn`` only for removal (退购) queries.
    """

    top10 = top1 = tempLibList = topn = None
|
|
|
|
|
|
class similarity:
    """Dispatch one similarity query to ``similarity_interface``.

    ``similarity_interface`` returns a differently shaped tuple depending on
    the query's ``add_flag`` / ``barcode_flag``; this class unpacks it into
    the matching fields of a ``similarityResult``.
    """

    def getSimilarity(self, model, dataCollection, similarityRes):
        """Run one similarity query and store its results in ``similarityRes``.

        Args:
            model: object exposing ``milvusModel`` and ``similarityModel``;
                both are copied onto ``dataCollection`` before the query.
            dataCollection: a ``datacollection`` whose ``add_flag`` and
                ``barcode_flag`` select which result tuple is expected.
            similarityRes: a ``similarityResult`` that is filled in place.

        Returns:
            ``similarityRes``, on every branch.
        """
        dataCollection.mainMilvus = model.milvusModel
        dataCollection.model = model.similarityModel

        # NOTE: no handling is active for
        # pymilvus.exceptions.SchemaNotReadyException (feature library does
        # not exist); an earlier handler for it was commented out.
        if dataCollection.add_flag:
            if dataCollection.barcode_flag:
                # Add (加购) with barcode -> top10, top1 and the temp library list.
                (similarityRes.top10,
                 similarityRes.top1,
                 similarityRes.tempLibList) = similarity_interface(dataCollection)
                print(f"top10: {similarityRes.top10}\ntop1: {similarityRes.top1}")
            else:
                # Add (加购) without barcode -> top10 and the temp library list.
                (similarityRes.top10,
                 similarityRes.tempLibList) = similarity_interface(dataCollection)
        elif dataCollection.barcode_flag:
            # Removal (退购) with barcode -> top10, top1 and topn.
            (similarityRes.top10,
             similarityRes.top1,
             similarityRes.topn) = similarity_interface(dataCollection)
        else:
            # Removal (退购) without barcode -> top10 and topn.
            similarityRes.top10, similarityRes.topn = similarity_interface(dataCollection)

        # Returned unconditionally so add queries get the result back too.
        return similarityRes
|
|
|
|
def main(model=None):
    """Run one similarity query on the sample ``queueImgs_add`` payload.

    Args:
        model: object exposing ``milvusModel`` and ``similarityModel``, as
            required by ``similarity.getSimilarity``.

    Returns:
        The filled ``similarityResult``.

    Raises:
        ValueError: if ``model`` is not supplied.
    """
    if model is None:
        # getSimilarity dereferences model.milvusModel / model.similarityModel;
        # fail early with a clear message instead of an opaque error.
        raise ValueError("main() requires a model exposing milvusModel and similarityModel")

    data_collection = datacollection()
    data_collection.barcode_flag = queueImgs_add['barcode_flag']
    data_collection.add_flag = queueImgs_add['add_flag']
    data_collection.queImgsDict = queueImgs_add

    # getSimilarity takes (model, dataCollection, similarityRes); the model
    # argument was previously omitted, which always raised TypeError.
    return similarity().getSimilarity(model, data_collection, similarityResult())


if __name__ == '__main__':
    # TODO: construct/load the model here; main() raises ValueError without one.
    main()
|