update
This commit is contained in:
317
contrast/test_logic.py
Normal file
317
contrast/test_logic.py
Normal file
@ -0,0 +1,317 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import pdb
|
||||
import random
|
||||
import json
|
||||
import time
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from contrast.model import resnet18, MobileNetV3_Large
|
||||
# import pymilvus
|
||||
# from pymilvus import (
|
||||
# connections,
|
||||
# utility,
|
||||
# FieldSchema, CollectionSchema, DataType,
|
||||
# Collection,
|
||||
# Milvus
|
||||
# )
|
||||
# from config import config as conf
|
||||
from contrast.search import ImgSearch
|
||||
from contrast.img_data import queueImgs_add
|
||||
import sys
|
||||
from threading import Thread
|
||||
|
||||
|
||||
sys.path.append('../tools')
|
||||
from tools.config import cfg as conf
|
||||
from tools.config import gvalue
|
||||
|
||||
|
||||
|
||||
def test_preprocess(images: list, actionModel) -> torch.Tensor:
|
||||
res = []
|
||||
for img in images:
|
||||
# print(img)
|
||||
try:
|
||||
im = conf.test_transform(img) if actionModel else conf.test_transform(Image.open(img))
|
||||
res.append(im)
|
||||
except:
|
||||
continue
|
||||
data = torch.stack(res)
|
||||
return data
|
||||
|
||||
|
||||
def inference(images, model, actionModel):
|
||||
data = test_preprocess(images, actionModel)
|
||||
if torch.cuda.is_available():
|
||||
data = data.to(conf.device)
|
||||
features = model(data)
|
||||
return features
|
||||
|
||||
|
||||
def group_image(images, batch=64) -> list:
|
||||
"""Group image paths by batch size"""
|
||||
size = len(images)
|
||||
res = []
|
||||
for i in range(0, size, batch):
|
||||
end = min(batch + i, size)
|
||||
res.append(images[i:end])
|
||||
return res
|
||||
|
||||
|
||||
def barcode_state(barcodeIDList):
|
||||
with open('contrast/main_barcodes.json', 'r') as file:
|
||||
data = json.load(file)
|
||||
main_barcode = list(data.values())[0]
|
||||
barIdList_true = []
|
||||
barIdList_false = []
|
||||
for barId in barcodeIDList:
|
||||
bar = barId.split('_')[1]
|
||||
if bar in main_barcode:
|
||||
barIdList_true.append(barId)
|
||||
else:
|
||||
barIdList_false.append(barId)
|
||||
return barIdList_true, barIdList_false
|
||||
|
||||
|
||||
def getFeatureList(barList, imgList, model, actionModel):
|
||||
featList = [[] for i in range(len(barList))]
|
||||
for index, feat in enumerate(imgList):
|
||||
groups = group_image(feat)
|
||||
for group in groups:
|
||||
feat_tensor = inference(group, model, actionModel)
|
||||
for fe in feat_tensor:
|
||||
if fe.device == 'cpu':
|
||||
fe_np = fe.squeeze().detach().numpy()
|
||||
else:
|
||||
fe_np = fe.squeeze().detach().cpu().numpy()
|
||||
featList[index].append(fe_np)
|
||||
return featList
|
||||
|
||||
|
||||
def img2feature(imgs_dict, model, actionModel, barcode_flag):
|
||||
if not len(imgs_dict) > 0:
|
||||
raise ValueError("Tracking fail no images files provided")
|
||||
queBarIdList = list(imgs_dict.keys())
|
||||
if barcode_flag:
|
||||
# # ========判断barcode是否在特征库============
|
||||
queBarIdList_t, barIdList_f = barcode_state(queBarIdList)
|
||||
queFeatList_t = []
|
||||
if len(queBarIdList_t) == 0:
|
||||
print(f"All barcodes are not in the main_library: {barIdList_f}")
|
||||
return queBarIdList_t, queFeatList_t
|
||||
else:
|
||||
if len(barIdList_f) > 0: ## 将不在barcode库中的barcode及图片删除
|
||||
print(f"These barcodes are not in the main_library: {barIdList_f}")
|
||||
for bar_f in barIdList_f:
|
||||
del imgs_dict[bar_f]
|
||||
queImgList_t = list(imgs_dict.values())
|
||||
queFeatList_t = getFeatureList(queBarIdList_t, queImgList_t, model, actionModel)
|
||||
return queBarIdList_t, queFeatList_t
|
||||
else:
|
||||
queImgsList = list(imgs_dict.values())
|
||||
queFeatList = getFeatureList(queBarIdList, queImgsList, model, actionModel)
|
||||
return queBarIdList, queFeatList
|
||||
|
||||
|
||||
|
||||
# def create_milvus(collection_name, host, port, barcode_list, features):
|
||||
# # 1. connect to Milvus
|
||||
# fmt = "\n=== {:30} ===\n"
|
||||
# connections.connect('default', host=host, port=port) # 连接到 Milvus 服务器
|
||||
# has = utility.has_collection(collection_name) ##检查collection_name是否存在milvus中
|
||||
# print(f"Does collection {collection_name} exist in Milvus: {has}")
|
||||
# # if has: ## 删除collection_name的库
|
||||
# # utility.drop_collection(collection_name)
|
||||
#
|
||||
# # 2. create colllection
|
||||
# fields = [
|
||||
# FieldSchema(name="pk", dtype=DataType.VARCHAR, is_primary=True, auto_id=False, max_length=100), ###图片路径
|
||||
# FieldSchema(name="embeddings", dtype=DataType.FLOAT_VECTOR, dim=256)
|
||||
# ]
|
||||
# schema = CollectionSchema(fields)
|
||||
# print(fmt.format(f"Create collection {collection_name}"))
|
||||
# hello_milvus = Collection(collection_name, schema, consistency_level="Strong")
|
||||
# # 3. insert data
|
||||
# for i in range(len(features)):
|
||||
# entities = [
|
||||
# # provide the pk field because `auto_id` is set to False
|
||||
# [barcode_list[i]] * len(features[i]), ## 图片维度和向量维度需匹配 每个向量都生成一个barcode
|
||||
# features[i],
|
||||
# ]
|
||||
# print(fmt.format("Start inserting entities"))
|
||||
# insert_result = hello_milvus.insert(entities)
|
||||
# hello_milvus.flush()
|
||||
# print(f"Number of entities in {collection_name}: {hello_milvus.num_entities}") # check the num_entities
|
||||
# return hello_milvus
|
||||
|
||||
|
||||
# def load_collection(collection_name):
|
||||
# collection = Collection(collection_name)
|
||||
# # collection.release() ### 将collection从加载状态变成未加载
|
||||
# # collection.drop_index() ### 删除索引
|
||||
#
|
||||
# index_params = {
|
||||
# "index_type": "IVF_FLAT",
|
||||
# # "index_type": "IVF_SQ8",
|
||||
# # "index_type": "GPU_IVF_FLAT",
|
||||
# "metric_type": "COSINE",
|
||||
# "params": {
|
||||
# "nlist": 10000
|
||||
# }
|
||||
# }
|
||||
# #### 准确率低
|
||||
# # index_params = {
|
||||
# # "index_type": "IVF_PQ",
|
||||
# # "metric_type": "COSINE",
|
||||
# # "params": {
|
||||
# # "nlist": 99,
|
||||
# # "m": 2,
|
||||
# # "nbits": 8
|
||||
# # }
|
||||
# # }
|
||||
# collection.create_index(
|
||||
# field_name="embeddings",
|
||||
# index_params=index_params,
|
||||
# index_name="SQ8"
|
||||
# )
|
||||
# collection.load()
|
||||
# return collection
|
||||
|
||||
|
||||
# def similarity(queImgsDict, add_flag, barcode_flag, main_milvus, model, barcode_list, actionModel):
|
||||
# searchImg = ImgSearch() ## 相似度比较
|
||||
# # 将输入图片加入临时库
|
||||
# if add_flag:
|
||||
# if actionModel:
|
||||
# queBarIdList, queBarIdFeatures = img2feature(dict(list(queImgsDict.items())[2:-2]), model, actionModel, barcode_flag)
|
||||
# else:
|
||||
# queBarIdList, queBarIdFeatures = img2feature(dict(list(queImgsDict.items())[:-2]), model, actionModel, barcode_flag)
|
||||
#
|
||||
# if barcode_flag: ### 加购 有barcode -> 输出top10和top1
|
||||
# if len(queBarIdList) == 0:
|
||||
# top10, top1 = {}, {}
|
||||
# else:
|
||||
# for bar in queBarIdList:
|
||||
# # gvalue.tempLibList.append(bar) ## 临时特征库key值为macID_barcode_trackID
|
||||
# if gvalue.tempLibLists.get(gvalue.mac_id) is not None:
|
||||
# gvalue.tempLibLists[gvalue.mac_id] += [bar] ## 临时特征库key值为macID_barcode_trackID
|
||||
# else:
|
||||
# gvalue.tempLibLists[gvalue.mac_id] = [bar]
|
||||
# # 存入临时特征库
|
||||
# # create_milvus('temp_features', conf.host, conf.port, queBarIdList, queBarIdFeatures)
|
||||
#
|
||||
# thread = Thread(target=create_milvus, kwargs={'collection_name': 'temp_features',
|
||||
# 'host': conf.host,
|
||||
# 'port': conf.port,
|
||||
# 'barcode_list': queBarIdList,
|
||||
# 'features': queBarIdFeatures})
|
||||
# thread.start()
|
||||
# start1 = time.time()
|
||||
# top10 = searchImg.mainSearch10(main_milvus, queBarIdList, queBarIdFeatures)
|
||||
# start2 = time.time()
|
||||
# print('search top10 time>>>> {}'.format(start2-start1))
|
||||
# top1 = searchImg.mainSearch1(main_milvus, queBarIdList, queBarIdFeatures)
|
||||
# start3 = time.time()
|
||||
# print('search top1 time>>>>> {}'.format(start3-start2))
|
||||
# return top10, top1, gvalue.tempLibLists
|
||||
# else: # 加购 无barcode -> 输出top10
|
||||
# # 无barcode时,生成随机数作为字典key值
|
||||
# queBarIdList_rand = []
|
||||
# for i in range(len(queBarIdList)):
|
||||
# random_number = ''.join(random.choices('0123456789', k=10))
|
||||
# queBarIdList_rand.append(str(random_number))
|
||||
# # gvalue.tempLibList.append(str(random_number))
|
||||
# if gvalue.tempLibLists.get(gvalue.mac_id) is not None:
|
||||
# gvalue.tempLibLists[gvalue.mac_id] += [str(random_number)] ## 临时特征库key值为macID_barcode_trackID
|
||||
# else:
|
||||
# gvalue.tempLibLists[gvalue.mac_id] = [str(random_number)]
|
||||
# # create_milvus('temp_features', conf.host, conf.port, queBarIdList_rand, queBarIdFeatures)
|
||||
# thread = Thread(target=create_milvus, kwargs={'collection_name': 'temp_features',
|
||||
# 'host': conf.host,
|
||||
# 'port': conf.port,
|
||||
# 'barcode_list': queBarIdList_rand,
|
||||
# 'features': queBarIdFeatures})
|
||||
# thread.start()
|
||||
# top10 = searchImg.mainSearch10(main_milvus, queBarIdList, queBarIdFeatures)
|
||||
# # print(f'top10: {top10}')
|
||||
# return top10, gvalue.tempLibLists
|
||||
# else: # 退购 -> 输出top10和topn
|
||||
# if gvalue.tempLibLists.get(gvalue.mac_id) is None:
|
||||
# gvalue.tempLibList = []
|
||||
# else:
|
||||
# gvalue.tempLibList = gvalue.tempLibLists[gvalue.mac_id]
|
||||
# ## 加载临时特征库
|
||||
# tempMilvusName = "temp_features"
|
||||
# has = utility.has_collection(tempMilvusName)
|
||||
# print(f"Does collection {tempMilvusName} exist in Milvus: {has}")
|
||||
# tempMilvus = load_collection(tempMilvusName)
|
||||
# print(f"Number of entities in {tempMilvusName}: {tempMilvus.num_entities}")
|
||||
# if actionModel:
|
||||
# barcode_list = barcode_list
|
||||
# else:
|
||||
# barcode_list = queueImgs_add['barcode_list']
|
||||
# if actionModel:
|
||||
# queBarIdList, queBarIdFeatures = img2feature(dict(list(queImgsDict.items())[2:-1]), model, actionModel, barcode_flag)
|
||||
# else:
|
||||
# queBarIdList, queBarIdFeatures = img2feature(dict(list(queImgsDict.items())[:-3]), model, actionModel, barcode_flag)
|
||||
# if barcode_flag:
|
||||
# if len(queBarIdList) == 0:
|
||||
# top10, top1, top_n = {}, {}, {}
|
||||
# else:
|
||||
# start1 = time.time()
|
||||
# top1 = searchImg.mainSearch1(main_milvus, queBarIdList, queBarIdFeatures)
|
||||
# start2 = time.time()
|
||||
# print('search top1 time>>>> {}'.format(start2 - start1))
|
||||
# top10 = searchImg.mainSearch10(main_milvus, queBarIdList, queBarIdFeatures)
|
||||
# start3 = time.time()
|
||||
# print('search top10 time>>>> {}'.format(start3 - start2))
|
||||
# top_n = searchImg.tempSearch(tempMilvus, queBarIdList, queBarIdFeatures, barcode_list, gvalue.tempLibList)
|
||||
# # print(f'top10: {top10}, top1: {top1}, topn: {top_n}')
|
||||
# return top10, top1, top_n
|
||||
# else:
|
||||
# top10 = searchImg.mainSearch10(main_milvus, queBarIdList, queBarIdFeatures)
|
||||
# top_n = searchImg.tempSearch(tempMilvus, queBarIdList, queBarIdFeatures, barcode_list, gvalue.tempLibList)
|
||||
# # print(f'top10: {top10}, topn: {top_n}')
|
||||
# return top10, top_n
|
||||
|
||||
|
||||
def similarity_interface(dataCollection):
|
||||
queImgsDict = dataCollection.queImgsDict
|
||||
add_flag = dataCollection.add_flag
|
||||
barcode_flag = dataCollection.barcode_flag
|
||||
main_milvus = dataCollection.mainMilvus
|
||||
#tempLibList = dataCollection.tempLibList
|
||||
model = dataCollection.model
|
||||
actionModel = dataCollection.actionModel
|
||||
barcode_list = dataCollection.barcode_list
|
||||
#return similarity(queImgsDict, add_flag, barcode_flag, main_milvus, tempLibList, model, barcode_list, actionModel)
|
||||
return 0
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pass
|
||||
|
||||
# connections.connect('default', host=conf.host, port=conf.port)
|
||||
# # 加载主特征库
|
||||
# mainMilvusName = "main_features"
|
||||
# has = utility.has_collection(mainMilvusName)
|
||||
# print(f"Does collection {mainMilvusName} exist in Milvus: {has}")
|
||||
# mainMilvus = Collection(mainMilvusName)
|
||||
# mainMilvus.load()
|
||||
# model = initModel()
|
||||
# # queueImgs_add queueImgs_back 分别为加购和退购时的入参
|
||||
# add_flag = queueImgs_add['add_flag']
|
||||
# barcode_flag = queueImgs_add['barcode_flag']
|
||||
# tempLibList = [] # 临时特征库的barcodeId_list
|
||||
# # tempLibList = ['3500610085338_01', '4260290263776_01'] ##test
|
||||
# if add_flag:
|
||||
# if barcode_flag: # 加购 有barcode -> 输出top10和top1
|
||||
# top10, top1, tempLibList = similarity(queueImgs_add, add_flag, barcode_flag, mainMilvus, tempLibList, model)
|
||||
# print(f"top10: {top10}\ntop1: {top1}")
|
||||
# else: # 加购 无barcode -> 输出top10
|
||||
# top10, tempLibList = similarity(queueImgs_add, add_flag, barcode_flag, mainMilvus, tempLibList, model)
|
||||
# else: # 退购 -> 输出top10和topn
|
||||
# top10, topn = similarity(queueImgs_back, add_flag, barcode_flag, mainMilvus, tempLibList, model)
|
Reference in New Issue
Block a user