This commit is contained in:
lee
2024-11-27 15:37:10 +08:00
commit 3a5214c796
696 changed files with 56947 additions and 0 deletions

317
contrast/test_logic.py Normal file
View File

@ -0,0 +1,317 @@
# -*- coding: utf-8 -*-
import pdb
import random
import json
import time
import torch
from PIL import Image
from contrast.model import resnet18, MobileNetV3_Large
# import pymilvus
# from pymilvus import (
# connections,
# utility,
# FieldSchema, CollectionSchema, DataType,
# Collection,
# Milvus
# )
# from config import config as conf
from contrast.search import ImgSearch
from contrast.img_data import queueImgs_add
import sys
from threading import Thread
sys.path.append('../tools')
from tools.config import cfg as conf
from tools.config import gvalue
def test_preprocess(images: list, actionModel) -> torch.Tensor:
res = []
for img in images:
# print(img)
try:
im = conf.test_transform(img) if actionModel else conf.test_transform(Image.open(img))
res.append(im)
except:
continue
data = torch.stack(res)
return data
def inference(images, model, actionModel):
data = test_preprocess(images, actionModel)
if torch.cuda.is_available():
data = data.to(conf.device)
features = model(data)
return features
def group_image(images, batch=64) -> list:
"""Group image paths by batch size"""
size = len(images)
res = []
for i in range(0, size, batch):
end = min(batch + i, size)
res.append(images[i:end])
return res
def barcode_state(barcodeIDList):
with open('contrast/main_barcodes.json', 'r') as file:
data = json.load(file)
main_barcode = list(data.values())[0]
barIdList_true = []
barIdList_false = []
for barId in barcodeIDList:
bar = barId.split('_')[1]
if bar in main_barcode:
barIdList_true.append(barId)
else:
barIdList_false.append(barId)
return barIdList_true, barIdList_false
def getFeatureList(barList, imgList, model, actionModel):
featList = [[] for i in range(len(barList))]
for index, feat in enumerate(imgList):
groups = group_image(feat)
for group in groups:
feat_tensor = inference(group, model, actionModel)
for fe in feat_tensor:
if fe.device == 'cpu':
fe_np = fe.squeeze().detach().numpy()
else:
fe_np = fe.squeeze().detach().cpu().numpy()
featList[index].append(fe_np)
return featList
def img2feature(imgs_dict, model, actionModel, barcode_flag):
if not len(imgs_dict) > 0:
raise ValueError("Tracking fail no images files provided")
queBarIdList = list(imgs_dict.keys())
if barcode_flag:
# # ========判断barcode是否在特征库============
queBarIdList_t, barIdList_f = barcode_state(queBarIdList)
queFeatList_t = []
if len(queBarIdList_t) == 0:
print(f"All barcodes are not in the main_library: {barIdList_f}")
return queBarIdList_t, queFeatList_t
else:
if len(barIdList_f) > 0: ## 将不在barcode库中的barcode及图片删除
print(f"These barcodes are not in the main_library: {barIdList_f}")
for bar_f in barIdList_f:
del imgs_dict[bar_f]
queImgList_t = list(imgs_dict.values())
queFeatList_t = getFeatureList(queBarIdList_t, queImgList_t, model, actionModel)
return queBarIdList_t, queFeatList_t
else:
queImgsList = list(imgs_dict.values())
queFeatList = getFeatureList(queBarIdList, queImgsList, model, actionModel)
return queBarIdList, queFeatList
# def create_milvus(collection_name, host, port, barcode_list, features):
# # 1. connect to Milvus
# fmt = "\n=== {:30} ===\n"
# connections.connect('default', host=host, port=port) # 连接到 Milvus 服务器
# has = utility.has_collection(collection_name) ##检查collection_name是否存在milvus中
# print(f"Does collection {collection_name} exist in Milvus: {has}")
# # if has: ## 删除collection_name的库
# # utility.drop_collection(collection_name)
#
# # 2. create colllection
# fields = [
# FieldSchema(name="pk", dtype=DataType.VARCHAR, is_primary=True, auto_id=False, max_length=100), ###图片路径
# FieldSchema(name="embeddings", dtype=DataType.FLOAT_VECTOR, dim=256)
# ]
# schema = CollectionSchema(fields)
# print(fmt.format(f"Create collection {collection_name}"))
# hello_milvus = Collection(collection_name, schema, consistency_level="Strong")
# # 3. insert data
# for i in range(len(features)):
# entities = [
# # provide the pk field because `auto_id` is set to False
# [barcode_list[i]] * len(features[i]), ## 图片维度和向量维度需匹配 每个向量都生成一个barcode
# features[i],
# ]
# print(fmt.format("Start inserting entities"))
# insert_result = hello_milvus.insert(entities)
# hello_milvus.flush()
# print(f"Number of entities in {collection_name}: {hello_milvus.num_entities}") # check the num_entities
# return hello_milvus
# def load_collection(collection_name):
# collection = Collection(collection_name)
# # collection.release() ### 将collection从加载状态变成未加载
# # collection.drop_index() ### 删除索引
#
# index_params = {
# "index_type": "IVF_FLAT",
# # "index_type": "IVF_SQ8",
# # "index_type": "GPU_IVF_FLAT",
# "metric_type": "COSINE",
# "params": {
# "nlist": 10000
# }
# }
# #### 准确率低
# # index_params = {
# # "index_type": "IVF_PQ",
# # "metric_type": "COSINE",
# # "params": {
# # "nlist": 99,
# # "m": 2,
# # "nbits": 8
# # }
# # }
# collection.create_index(
# field_name="embeddings",
# index_params=index_params,
# index_name="SQ8"
# )
# collection.load()
# return collection
# def similarity(queImgsDict, add_flag, barcode_flag, main_milvus, model, barcode_list, actionModel):
# searchImg = ImgSearch() ## 相似度比较
# # 将输入图片加入临时库
# if add_flag:
# if actionModel:
# queBarIdList, queBarIdFeatures = img2feature(dict(list(queImgsDict.items())[2:-2]), model, actionModel, barcode_flag)
# else:
# queBarIdList, queBarIdFeatures = img2feature(dict(list(queImgsDict.items())[:-2]), model, actionModel, barcode_flag)
#
# if barcode_flag: ### 加购 有barcode -> 输出top10和top1
# if len(queBarIdList) == 0:
# top10, top1 = {}, {}
# else:
# for bar in queBarIdList:
# # gvalue.tempLibList.append(bar) ## 临时特征库key值为macID_barcode_trackID
# if gvalue.tempLibLists.get(gvalue.mac_id) is not None:
# gvalue.tempLibLists[gvalue.mac_id] += [bar] ## 临时特征库key值为macID_barcode_trackID
# else:
# gvalue.tempLibLists[gvalue.mac_id] = [bar]
# # 存入临时特征库
# # create_milvus('temp_features', conf.host, conf.port, queBarIdList, queBarIdFeatures)
#
# thread = Thread(target=create_milvus, kwargs={'collection_name': 'temp_features',
# 'host': conf.host,
# 'port': conf.port,
# 'barcode_list': queBarIdList,
# 'features': queBarIdFeatures})
# thread.start()
# start1 = time.time()
# top10 = searchImg.mainSearch10(main_milvus, queBarIdList, queBarIdFeatures)
# start2 = time.time()
# print('search top10 time>>>> {}'.format(start2-start1))
# top1 = searchImg.mainSearch1(main_milvus, queBarIdList, queBarIdFeatures)
# start3 = time.time()
# print('search top1 time>>>>> {}'.format(start3-start2))
# return top10, top1, gvalue.tempLibLists
# else: # 加购 无barcode -> 输出top10
# # 无barcode时生成随机数作为字典key值
# queBarIdList_rand = []
# for i in range(len(queBarIdList)):
# random_number = ''.join(random.choices('0123456789', k=10))
# queBarIdList_rand.append(str(random_number))
# # gvalue.tempLibList.append(str(random_number))
# if gvalue.tempLibLists.get(gvalue.mac_id) is not None:
# gvalue.tempLibLists[gvalue.mac_id] += [str(random_number)] ## 临时特征库key值为macID_barcode_trackID
# else:
# gvalue.tempLibLists[gvalue.mac_id] = [str(random_number)]
# # create_milvus('temp_features', conf.host, conf.port, queBarIdList_rand, queBarIdFeatures)
# thread = Thread(target=create_milvus, kwargs={'collection_name': 'temp_features',
# 'host': conf.host,
# 'port': conf.port,
# 'barcode_list': queBarIdList_rand,
# 'features': queBarIdFeatures})
# thread.start()
# top10 = searchImg.mainSearch10(main_milvus, queBarIdList, queBarIdFeatures)
# # print(f'top10: {top10}')
# return top10, gvalue.tempLibLists
# else: # 退购 -> 输出top10和topn
# if gvalue.tempLibLists.get(gvalue.mac_id) is None:
# gvalue.tempLibList = []
# else:
# gvalue.tempLibList = gvalue.tempLibLists[gvalue.mac_id]
# ## 加载临时特征库
# tempMilvusName = "temp_features"
# has = utility.has_collection(tempMilvusName)
# print(f"Does collection {tempMilvusName} exist in Milvus: {has}")
# tempMilvus = load_collection(tempMilvusName)
# print(f"Number of entities in {tempMilvusName}: {tempMilvus.num_entities}")
# if actionModel:
# barcode_list = barcode_list
# else:
# barcode_list = queueImgs_add['barcode_list']
# if actionModel:
# queBarIdList, queBarIdFeatures = img2feature(dict(list(queImgsDict.items())[2:-1]), model, actionModel, barcode_flag)
# else:
# queBarIdList, queBarIdFeatures = img2feature(dict(list(queImgsDict.items())[:-3]), model, actionModel, barcode_flag)
# if barcode_flag:
# if len(queBarIdList) == 0:
# top10, top1, top_n = {}, {}, {}
# else:
# start1 = time.time()
# top1 = searchImg.mainSearch1(main_milvus, queBarIdList, queBarIdFeatures)
# start2 = time.time()
# print('search top1 time>>>> {}'.format(start2 - start1))
# top10 = searchImg.mainSearch10(main_milvus, queBarIdList, queBarIdFeatures)
# start3 = time.time()
# print('search top10 time>>>> {}'.format(start3 - start2))
# top_n = searchImg.tempSearch(tempMilvus, queBarIdList, queBarIdFeatures, barcode_list, gvalue.tempLibList)
# # print(f'top10: {top10}, top1: {top1}, topn: {top_n}')
# return top10, top1, top_n
# else:
# top10 = searchImg.mainSearch10(main_milvus, queBarIdList, queBarIdFeatures)
# top_n = searchImg.tempSearch(tempMilvus, queBarIdList, queBarIdFeatures, barcode_list, gvalue.tempLibList)
# # print(f'top10: {top10}, topn: {top_n}')
# return top10, top_n
def similarity_interface(dataCollection):
queImgsDict = dataCollection.queImgsDict
add_flag = dataCollection.add_flag
barcode_flag = dataCollection.barcode_flag
main_milvus = dataCollection.mainMilvus
#tempLibList = dataCollection.tempLibList
model = dataCollection.model
actionModel = dataCollection.actionModel
barcode_list = dataCollection.barcode_list
#return similarity(queImgsDict, add_flag, barcode_flag, main_milvus, tempLibList, model, barcode_list, actionModel)
return 0
if __name__ == '__main__':
pass
# connections.connect('default', host=conf.host, port=conf.port)
# # 加载主特征库
# mainMilvusName = "main_features"
# has = utility.has_collection(mainMilvusName)
# print(f"Does collection {mainMilvusName} exist in Milvus: {has}")
# mainMilvus = Collection(mainMilvusName)
# mainMilvus.load()
# model = initModel()
# # queueImgs_add queueImgs_back 分别为加购和退购时的入参
# add_flag = queueImgs_add['add_flag']
# barcode_flag = queueImgs_add['barcode_flag']
# tempLibList = [] # 临时特征库的barcodeId_list
# # tempLibList = ['3500610085338_01', '4260290263776_01'] ##test
# if add_flag:
# if barcode_flag: # 加购 有barcode -> 输出top10和top1
# top10, top1, tempLibList = similarity(queueImgs_add, add_flag, barcode_flag, mainMilvus, tempLibList, model)
# print(f"top10: {top10}\ntop1: {top1}")
# else: # 加购 无barcode -> 输出top10
# top10, tempLibList = similarity(queueImgs_add, add_flag, barcode_flag, mainMilvus, tempLibList, model)
# else: # 退购 -> 输出top10和topn
# top10, topn = similarity(queueImgs_back, add_flag, barcode_flag, mainMilvus, tempLibList, model)