318 lines
14 KiB
Python
318 lines
14 KiB
Python
# -*- coding: utf-8 -*-
|
||
import pdb
|
||
import random
|
||
import json
|
||
import time
|
||
|
||
import torch
|
||
from PIL import Image
|
||
|
||
from contrast.model import resnet18, MobileNetV3_Large
|
||
# import pymilvus
|
||
# from pymilvus import (
|
||
# connections,
|
||
# utility,
|
||
# FieldSchema, CollectionSchema, DataType,
|
||
# Collection,
|
||
# Milvus
|
||
# )
|
||
# from config import config as conf
|
||
from contrast.search import ImgSearch
|
||
from contrast.img_data import queueImgs_add
|
||
import sys
|
||
from threading import Thread
|
||
|
||
|
||
sys.path.append('../tools')
|
||
from tools.config import cfg as conf
|
||
from tools.config import gvalue
|
||
|
||
|
||
|
||
def test_preprocess(images: list, actionModel) -> torch.Tensor:
    """Apply ``conf.test_transform`` to each image and stack the results.

    Args:
        images: When ``actionModel`` is truthy, each item is passed directly to
            ``conf.test_transform`` (presumably an already-decoded image --
            confirm against callers). Otherwise each item is a path opened
            with ``PIL.Image.open``.
        actionModel: Selects between the two input modes above.

    Returns:
        ``torch.stack`` of every successfully transformed image. Images that
        fail to load or transform are skipped.

    Raises:
        RuntimeError: if none of the images could be loaded/transformed.
    """
    res = []
    for img in images:
        try:
            if actionModel:
                im = conf.test_transform(img)
            else:
                # Close the underlying file handle once the transform has
                # consumed the pixels; a bare Image.open() keeps it open
                # until the image object is garbage-collected.
                with Image.open(img) as pil_img:
                    im = conf.test_transform(pil_img)
        except Exception:
            # Best-effort: skip unreadable/corrupt images instead of failing
            # the whole batch. Catching Exception (not a bare except) still
            # lets KeyboardInterrupt / SystemExit propagate.
            continue
        res.append(im)
    if not res:
        # torch.stack([]) would raise an opaque RuntimeError; say why instead.
        raise RuntimeError(
            "test_preprocess: none of the provided images could be loaded/transformed"
        )
    return torch.stack(res)
|
||
|
||
|
||
def inference(images, model, actionModel):
    """Preprocess ``images`` into one batch tensor and run ``model`` on it.

    Args:
        images: Forwarded to ``test_preprocess`` together with ``actionModel``.
        model: Callable applied to the stacked batch tensor.
        actionModel: Selects the preprocessing input mode.

    Returns:
        The model output for the batch, computed without autograd tracking.
    """
    data = test_preprocess(images, actionModel)
    if torch.cuda.is_available():
        data = data.to(conf.device)
    # Pure inference: skip building the autograd graph, which would otherwise
    # keep every intermediate activation alive as long as the output lives.
    with torch.no_grad():
        features = model(data)
    return features
|
||
|
||
|
||
def group_image(images, batch=64) -> list:
    """Split ``images`` into consecutive chunks of at most ``batch`` items.

    The last chunk holds whatever remains; an empty input yields ``[]``.
    """
    return [images[start:start + batch] for start in range(0, len(images), batch)]
|
||
|
||
|
||
def barcode_state(barcodeIDList, json_path='contrast/main_barcodes.json'):
    """Split barcode ids into those present in / absent from the main library.

    Args:
        barcodeIDList: Ids whose second ``'_'``-separated field is the barcode.
            An id containing no ``'_'`` raises IndexError.
        json_path: JSON file whose first value is the collection of known
            barcodes. The default is the path that was previously hard-coded
            (relative to the current working directory).

    Returns:
        ``(barIdList_true, barIdList_false)``: the ids whose barcode is / is
        not in the library, each in input order.
    """
    with open(json_path, 'r', encoding='utf-8') as file:
        data = json.load(file)

    main_barcode = list(data.values())[0]
    # Membership is tested once per id: hash lookup instead of a list scan.
    if isinstance(main_barcode, list):
        try:
            main_barcode = set(main_barcode)
        except TypeError:  # unhashable entries: keep list membership
            pass

    barIdList_true = []
    barIdList_false = []
    for barId in barcodeIDList:
        bar = barId.split('_')[1]
        if bar in main_barcode:
            barIdList_true.append(barId)
        else:
            barIdList_false.append(barId)
    return barIdList_true, barIdList_false
|
||
|
||
|
||
def getFeatureList(barList, imgList, model, actionModel):
    """Extract one feature vector per image, grouped like ``imgList``.

    Args:
        barList: Identifiers aligned with ``imgList``; only used here to size
            the output (one slot per entry).
        imgList: ``imgList[k]`` is the sequence of images for ``barList[k]``.
        model: Forwarded to ``inference``.
        actionModel: Forwarded to ``inference``.

    Returns:
        A list with one entry per ``barList`` item. Entry ``k`` holds one NumPy
        array per image in ``imgList[k]`` (each model output row, squeezed).
    """
    featList = [[] for _ in barList]
    for index, images in enumerate(imgList):
        for group in group_image(images):
            feat_tensor = inference(group, model, actionModel)
            # Detach and move the whole batch to host memory in one transfer.
            # .cpu() is a no-op for CPU tensors, so no device check is needed
            # (a torch.device never compares equal to a plain 'cpu' string).
            host_feats = feat_tensor.detach().cpu()
            featList[index].extend(fe.squeeze().numpy() for fe in host_feats)
    return featList
|
||
|
||
|
||
def img2feature(imgs_dict, model, actionModel, barcode_flag):
    """Compute image features for each barcode id in ``imgs_dict``.

    Args:
        imgs_dict: Mapping ``barcode_id -> images``. It is not modified.
        model: Forwarded to ``getFeatureList``.
        actionModel: Forwarded to ``getFeatureList``.
        barcode_flag: When truthy, ids whose barcode is not in the main library
            (per ``barcode_state``) are dropped before feature extraction.

    Returns:
        ``(barcode_ids, features)``, where ``features[k]`` is the list of
        feature arrays for ``barcode_ids[k]``. With ``barcode_flag`` set and
        no id in the library, returns ``([], [])``.

    Raises:
        ValueError: if ``imgs_dict`` is empty.
    """
    if not imgs_dict:
        raise ValueError("Tracking fail no images files provided")

    queBarIdList = list(imgs_dict.keys())

    if not barcode_flag:
        queImgsList = list(imgs_dict.values())
        queFeatList = getFeatureList(queBarIdList, queImgsList, model, actionModel)
        return queBarIdList, queFeatList

    # Check which barcodes exist in the main feature library.
    queBarIdList_t, barIdList_f = barcode_state(queBarIdList)
    if not queBarIdList_t:
        print(f"All barcodes are not in the main_library: {barIdList_f}")
        return queBarIdList_t, []

    if barIdList_f:
        print(f"These barcodes are not in the main_library: {barIdList_f}")

    # Select the images of the kept barcodes (same order as queBarIdList_t)
    # without deleting entries from the caller's dict.
    queImgList_t = [imgs_dict[bar_id] for bar_id in queBarIdList_t]
    queFeatList_t = getFeatureList(queBarIdList_t, queImgList_t, model, actionModel)
    return queBarIdList_t, queFeatList_t
|
||
|
||
|
||
|
||
# def create_milvus(collection_name, host, port, barcode_list, features):
|
||
# # 1. connect to Milvus
|
||
# fmt = "\n=== {:30} ===\n"
|
||
# connections.connect('default', host=host, port=port) # 连接到 Milvus 服务器
|
||
# has = utility.has_collection(collection_name) ##检查collection_name是否存在milvus中
|
||
# print(f"Does collection {collection_name} exist in Milvus: {has}")
|
||
# # if has: ## 删除collection_name的库
|
||
# # utility.drop_collection(collection_name)
|
||
#
|
||
# # 2. create collection
|
||
# fields = [
|
||
# FieldSchema(name="pk", dtype=DataType.VARCHAR, is_primary=True, auto_id=False, max_length=100), ###图片路径
|
||
# FieldSchema(name="embeddings", dtype=DataType.FLOAT_VECTOR, dim=256)
|
||
# ]
|
||
# schema = CollectionSchema(fields)
|
||
# print(fmt.format(f"Create collection {collection_name}"))
|
||
# hello_milvus = Collection(collection_name, schema, consistency_level="Strong")
|
||
# # 3. insert data
|
||
# for i in range(len(features)):
|
||
# entities = [
|
||
# # provide the pk field because `auto_id` is set to False
|
||
# [barcode_list[i]] * len(features[i]), ## 图片维度和向量维度需匹配 每个向量都生成一个barcode
|
||
# features[i],
|
||
# ]
|
||
# print(fmt.format("Start inserting entities"))
|
||
# insert_result = hello_milvus.insert(entities)
|
||
# hello_milvus.flush()
|
||
# print(f"Number of entities in {collection_name}: {hello_milvus.num_entities}") # check the num_entities
|
||
# return hello_milvus
|
||
|
||
|
||
# def load_collection(collection_name):
|
||
# collection = Collection(collection_name)
|
||
# # collection.release() ### 将collection从加载状态变成未加载
|
||
# # collection.drop_index() ### 删除索引
|
||
#
|
||
# index_params = {
|
||
# "index_type": "IVF_FLAT",
|
||
# # "index_type": "IVF_SQ8",
|
||
# # "index_type": "GPU_IVF_FLAT",
|
||
# "metric_type": "COSINE",
|
||
# "params": {
|
||
# "nlist": 10000
|
||
# }
|
||
# }
|
||
# #### 准确率低
|
||
# # index_params = {
|
||
# # "index_type": "IVF_PQ",
|
||
# # "metric_type": "COSINE",
|
||
# # "params": {
|
||
# # "nlist": 99,
|
||
# # "m": 2,
|
||
# # "nbits": 8
|
||
# # }
|
||
# # }
|
||
# collection.create_index(
|
||
# field_name="embeddings",
|
||
# index_params=index_params,
|
||
# index_name="SQ8"
|
||
# )
|
||
# collection.load()
|
||
# return collection
|
||
|
||
|
||
# def similarity(queImgsDict, add_flag, barcode_flag, main_milvus, model, barcode_list, actionModel):
|
||
# searchImg = ImgSearch() ## 相似度比较
|
||
# # 将输入图片加入临时库
|
||
# if add_flag:
|
||
# if actionModel:
|
||
# queBarIdList, queBarIdFeatures = img2feature(dict(list(queImgsDict.items())[2:-2]), model, actionModel, barcode_flag)
|
||
# else:
|
||
# queBarIdList, queBarIdFeatures = img2feature(dict(list(queImgsDict.items())[:-2]), model, actionModel, barcode_flag)
|
||
#
|
||
# if barcode_flag: ### 加购 有barcode -> 输出top10和top1
|
||
# if len(queBarIdList) == 0:
|
||
# top10, top1 = {}, {}
|
||
# else:
|
||
# for bar in queBarIdList:
|
||
# # gvalue.tempLibList.append(bar) ## 临时特征库key值为macID_barcode_trackID
|
||
# if gvalue.tempLibLists.get(gvalue.mac_id) is not None:
|
||
# gvalue.tempLibLists[gvalue.mac_id] += [bar] ## 临时特征库key值为macID_barcode_trackID
|
||
# else:
|
||
# gvalue.tempLibLists[gvalue.mac_id] = [bar]
|
||
# # 存入临时特征库
|
||
# # create_milvus('temp_features', conf.host, conf.port, queBarIdList, queBarIdFeatures)
|
||
#
|
||
# thread = Thread(target=create_milvus, kwargs={'collection_name': 'temp_features',
|
||
# 'host': conf.host,
|
||
# 'port': conf.port,
|
||
# 'barcode_list': queBarIdList,
|
||
# 'features': queBarIdFeatures})
|
||
# thread.start()
|
||
# start1 = time.time()
|
||
# top10 = searchImg.mainSearch10(main_milvus, queBarIdList, queBarIdFeatures)
|
||
# start2 = time.time()
|
||
# print('search top10 time>>>> {}'.format(start2-start1))
|
||
# top1 = searchImg.mainSearch1(main_milvus, queBarIdList, queBarIdFeatures)
|
||
# start3 = time.time()
|
||
# print('search top1 time>>>>> {}'.format(start3-start2))
|
||
# return top10, top1, gvalue.tempLibLists
|
||
# else: # 加购 无barcode -> 输出top10
|
||
# # 无barcode时,生成随机数作为字典key值
|
||
# queBarIdList_rand = []
|
||
# for i in range(len(queBarIdList)):
|
||
# random_number = ''.join(random.choices('0123456789', k=10))
|
||
# queBarIdList_rand.append(str(random_number))
|
||
# # gvalue.tempLibList.append(str(random_number))
|
||
# if gvalue.tempLibLists.get(gvalue.mac_id) is not None:
|
||
# gvalue.tempLibLists[gvalue.mac_id] += [str(random_number)] ## 临时特征库key值为macID_barcode_trackID
|
||
# else:
|
||
# gvalue.tempLibLists[gvalue.mac_id] = [str(random_number)]
|
||
# # create_milvus('temp_features', conf.host, conf.port, queBarIdList_rand, queBarIdFeatures)
|
||
# thread = Thread(target=create_milvus, kwargs={'collection_name': 'temp_features',
|
||
# 'host': conf.host,
|
||
# 'port': conf.port,
|
||
# 'barcode_list': queBarIdList_rand,
|
||
# 'features': queBarIdFeatures})
|
||
# thread.start()
|
||
# top10 = searchImg.mainSearch10(main_milvus, queBarIdList, queBarIdFeatures)
|
||
# # print(f'top10: {top10}')
|
||
# return top10, gvalue.tempLibLists
|
||
# else: # 退购 -> 输出top10和topn
|
||
# if gvalue.tempLibLists.get(gvalue.mac_id) is None:
|
||
# gvalue.tempLibList = []
|
||
# else:
|
||
# gvalue.tempLibList = gvalue.tempLibLists[gvalue.mac_id]
|
||
# ## 加载临时特征库
|
||
# tempMilvusName = "temp_features"
|
||
# has = utility.has_collection(tempMilvusName)
|
||
# print(f"Does collection {tempMilvusName} exist in Milvus: {has}")
|
||
# tempMilvus = load_collection(tempMilvusName)
|
||
# print(f"Number of entities in {tempMilvusName}: {tempMilvus.num_entities}")
|
||
# if actionModel:
|
||
# barcode_list = barcode_list
|
||
# else:
|
||
# barcode_list = queueImgs_add['barcode_list']
|
||
# if actionModel:
|
||
# queBarIdList, queBarIdFeatures = img2feature(dict(list(queImgsDict.items())[2:-1]), model, actionModel, barcode_flag)
|
||
# else:
|
||
# queBarIdList, queBarIdFeatures = img2feature(dict(list(queImgsDict.items())[:-3]), model, actionModel, barcode_flag)
|
||
# if barcode_flag:
|
||
# if len(queBarIdList) == 0:
|
||
# top10, top1, top_n = {}, {}, {}
|
||
# else:
|
||
# start1 = time.time()
|
||
# top1 = searchImg.mainSearch1(main_milvus, queBarIdList, queBarIdFeatures)
|
||
# start2 = time.time()
|
||
# print('search top1 time>>>> {}'.format(start2 - start1))
|
||
# top10 = searchImg.mainSearch10(main_milvus, queBarIdList, queBarIdFeatures)
|
||
# start3 = time.time()
|
||
# print('search top10 time>>>> {}'.format(start3 - start2))
|
||
# top_n = searchImg.tempSearch(tempMilvus, queBarIdList, queBarIdFeatures, barcode_list, gvalue.tempLibList)
|
||
# # print(f'top10: {top10}, top1: {top1}, topn: {top_n}')
|
||
# return top10, top1, top_n
|
||
# else:
|
||
# top10 = searchImg.mainSearch10(main_milvus, queBarIdList, queBarIdFeatures)
|
||
# top_n = searchImg.tempSearch(tempMilvus, queBarIdList, queBarIdFeatures, barcode_list, gvalue.tempLibList)
|
||
# # print(f'top10: {top10}, topn: {top_n}')
|
||
# return top10, top_n
|
||
|
||
|
||
def similarity_interface(dataCollection):
    """Entry point for the similarity search; currently a stub returning 0.

    Every field the similarity pipeline consumes is still read from
    ``dataCollection`` (so a missing attribute raises AttributeError), but the
    search itself is disabled and the function always returns ``0``.
    """
    # Read in the same order as the disabled pipeline expects its inputs.
    for field_name in ('queImgsDict', 'add_flag', 'barcode_flag', 'mainMilvus',
                       'model', 'actionModel', 'barcode_list'):
        getattr(dataCollection, field_name)
    # NOTE(review): the disabled call was
    #   similarity(queImgsDict, add_flag, barcode_flag, main_milvus, tempLibList,
    #              model, barcode_list, actionModel)
    # which passes tempLibList, a parameter the commented-out similarity()
    # signature does not take -- reconcile the two before re-enabling.
    return 0
|
||
|
||
|
||
|
||
if __name__ == "__main__":
    # Running this module directly does nothing; its functions are meant to
    # be imported.
    ...
|
||
|
||
# connections.connect('default', host=conf.host, port=conf.port)
|
||
# # 加载主特征库
|
||
# mainMilvusName = "main_features"
|
||
# has = utility.has_collection(mainMilvusName)
|
||
# print(f"Does collection {mainMilvusName} exist in Milvus: {has}")
|
||
# mainMilvus = Collection(mainMilvusName)
|
||
# mainMilvus.load()
|
||
# model = initModel()
|
||
# # queueImgs_add queueImgs_back 分别为加购和退购时的入参
|
||
# add_flag = queueImgs_add['add_flag']
|
||
# barcode_flag = queueImgs_add['barcode_flag']
|
||
# tempLibList = [] # 临时特征库的barcodeId_list
|
||
# # tempLibList = ['3500610085338_01', '4260290263776_01'] ##test
|
||
# if add_flag:
|
||
# if barcode_flag: # 加购 有barcode -> 输出top10和top1
|
||
# top10, top1, tempLibList = similarity(queueImgs_add, add_flag, barcode_flag, mainMilvus, tempLibList, model)
|
||
# print(f"top10: {top10}\ntop1: {top1}")
|
||
# else: # 加购 无barcode -> 输出top10
|
||
# top10, tempLibList = similarity(queueImgs_add, add_flag, barcode_flag, mainMilvus, tempLibList, model)
|
||
# else: # 退购 -> 输出top10和topn
|
||
# top10, topn = similarity(queueImgs_back, add_flag, barcode_flag, mainMilvus, tempLibList, model)
|