# -*- coding: utf-8 -*- import pdb import random import json import time import torch from PIL import Image from contrast.model import resnet18, MobileNetV3_Large # import pymilvus # from pymilvus import ( # connections, # utility, # FieldSchema, CollectionSchema, DataType, # Collection, # Milvus # ) # from config import config as conf from contrast.search import ImgSearch from contrast.img_data import queueImgs_add import sys from threading import Thread sys.path.append('../tools') from tools.config import cfg as conf from tools.config import gvalue def test_preprocess(images: list, actionModel) -> torch.Tensor: res = [] for img in images: # print(img) try: im = conf.test_transform(img) if actionModel else conf.test_transform(Image.open(img)) res.append(im) except: continue data = torch.stack(res) return data def inference(images, model, actionModel): data = test_preprocess(images, actionModel) if torch.cuda.is_available(): data = data.to(conf.device) features = model(data) return features def group_image(images, batch=64) -> list: """Group image paths by batch size""" size = len(images) res = [] for i in range(0, size, batch): end = min(batch + i, size) res.append(images[i:end]) return res def barcode_state(barcodeIDList): with open('contrast/main_barcodes.json', 'r') as file: data = json.load(file) main_barcode = list(data.values())[0] barIdList_true = [] barIdList_false = [] for barId in barcodeIDList: bar = barId.split('_')[1] if bar in main_barcode: barIdList_true.append(barId) else: barIdList_false.append(barId) return barIdList_true, barIdList_false def getFeatureList(barList, imgList, model, actionModel): featList = [[] for i in range(len(barList))] for index, feat in enumerate(imgList): groups = group_image(feat) for group in groups: feat_tensor = inference(group, model, actionModel) for fe in feat_tensor: if fe.device == 'cpu': fe_np = fe.squeeze().detach().numpy() else: fe_np = fe.squeeze().detach().cpu().numpy() featList[index].append(fe_np) return featList def img2feature(imgs_dict, model, actionModel, barcode_flag): if not len(imgs_dict) > 0: raise ValueError("Tracking fail no images files provided") queBarIdList = list(imgs_dict.keys()) if barcode_flag: # # ========判断barcode是否在特征库============ queBarIdList_t, barIdList_f = barcode_state(queBarIdList) queFeatList_t = [] if len(queBarIdList_t) == 0: print(f"All barcodes are not in the main_library: {barIdList_f}") return queBarIdList_t, queFeatList_t else: if len(barIdList_f) > 0: ## 将不在barcode库中的barcode及图片删除 print(f"These barcodes are not in the main_library: {barIdList_f}") for bar_f in barIdList_f: del imgs_dict[bar_f] queImgList_t = list(imgs_dict.values()) queFeatList_t = getFeatureList(queBarIdList_t, queImgList_t, model, actionModel) return queBarIdList_t, queFeatList_t else: queImgsList = list(imgs_dict.values()) queFeatList = getFeatureList(queBarIdList, queImgsList, model, actionModel) return queBarIdList, queFeatList # def create_milvus(collection_name, host, port, barcode_list, features): # # 1. connect to Milvus # fmt = "\n=== {:30} ===\n" # connections.connect('default', host=host, port=port) # 连接到 Milvus 服务器 # has = utility.has_collection(collection_name) ##检查collection_name是否存在milvus中 # print(f"Does collection {collection_name} exist in Milvus: {has}") # # if has: ## 删除collection_name的库 # # utility.drop_collection(collection_name) # # # 2. create colllection # fields = [ # FieldSchema(name="pk", dtype=DataType.VARCHAR, is_primary=True, auto_id=False, max_length=100), ###图片路径 # FieldSchema(name="embeddings", dtype=DataType.FLOAT_VECTOR, dim=256) # ] # schema = CollectionSchema(fields) # print(fmt.format(f"Create collection {collection_name}")) # hello_milvus = Collection(collection_name, schema, consistency_level="Strong") # # 3. insert data # for i in range(len(features)): # entities = [ # # provide the pk field because `auto_id` is set to False # [barcode_list[i]] * len(features[i]), ## 图片维度和向量维度需匹配 每个向量都生成一个barcode # features[i], # ] # print(fmt.format("Start inserting entities")) # insert_result = hello_milvus.insert(entities) # hello_milvus.flush() # print(f"Number of entities in {collection_name}: {hello_milvus.num_entities}") # check the num_entities # return hello_milvus # def load_collection(collection_name): # collection = Collection(collection_name) # # collection.release() ### 将collection从加载状态变成未加载 # # collection.drop_index() ### 删除索引 # # index_params = { # "index_type": "IVF_FLAT", # # "index_type": "IVF_SQ8", # # "index_type": "GPU_IVF_FLAT", # "metric_type": "COSINE", # "params": { # "nlist": 10000 # } # } # #### 准确率低 # # index_params = { # # "index_type": "IVF_PQ", # # "metric_type": "COSINE", # # "params": { # # "nlist": 99, # # "m": 2, # # "nbits": 8 # # } # # } # collection.create_index( # field_name="embeddings", # index_params=index_params, # index_name="SQ8" # ) # collection.load() # return collection # def similarity(queImgsDict, add_flag, barcode_flag, main_milvus, model, barcode_list, actionModel): # searchImg = ImgSearch() ## 相似度比较 # # 将输入图片加入临时库 # if add_flag: # if actionModel: # queBarIdList, queBarIdFeatures = img2feature(dict(list(queImgsDict.items())[2:-2]), model, actionModel, barcode_flag) # else: # queBarIdList, queBarIdFeatures = img2feature(dict(list(queImgsDict.items())[:-2]), model, actionModel, barcode_flag) # # if barcode_flag: ### 加购 有barcode -> 输出top10和top1 # if len(queBarIdList) == 0: # top10, top1 = {}, {} # else: # for bar in queBarIdList: # # gvalue.tempLibList.append(bar) ## 临时特征库key值为macID_barcode_trackID # if gvalue.tempLibLists.get(gvalue.mac_id) is not None: # gvalue.tempLibLists[gvalue.mac_id] += [bar] ## 临时特征库key值为macID_barcode_trackID # else: # gvalue.tempLibLists[gvalue.mac_id] = [bar] # # 存入临时特征库 # # create_milvus('temp_features', conf.host, conf.port, queBarIdList, queBarIdFeatures) # # thread = Thread(target=create_milvus, kwargs={'collection_name': 'temp_features', # 'host': conf.host, # 'port': conf.port, # 'barcode_list': queBarIdList, # 'features': queBarIdFeatures}) # thread.start() # start1 = time.time() # top10 = searchImg.mainSearch10(main_milvus, queBarIdList, queBarIdFeatures) # start2 = time.time() # print('search top10 time>>>> {}'.format(start2-start1)) # top1 = searchImg.mainSearch1(main_milvus, queBarIdList, queBarIdFeatures) # start3 = time.time() # print('search top1 time>>>>> {}'.format(start3-start2)) # return top10, top1, gvalue.tempLibLists # else: # 加购 无barcode -> 输出top10 # # 无barcode时,生成随机数作为字典key值 # queBarIdList_rand = [] # for i in range(len(queBarIdList)): # random_number = ''.join(random.choices('0123456789', k=10)) # queBarIdList_rand.append(str(random_number)) # # gvalue.tempLibList.append(str(random_number)) # if gvalue.tempLibLists.get(gvalue.mac_id) is not None: # gvalue.tempLibLists[gvalue.mac_id] += [str(random_number)] ## 临时特征库key值为macID_barcode_trackID # else: # gvalue.tempLibLists[gvalue.mac_id] = [str(random_number)] # # create_milvus('temp_features', conf.host, conf.port, queBarIdList_rand, queBarIdFeatures) # thread = Thread(target=create_milvus, kwargs={'collection_name': 'temp_features', # 'host': conf.host, # 'port': conf.port, # 'barcode_list': queBarIdList_rand, # 'features': queBarIdFeatures}) # thread.start() # top10 = searchImg.mainSearch10(main_milvus, queBarIdList, queBarIdFeatures) # # print(f'top10: {top10}') # return top10, gvalue.tempLibLists # else: # 退购 -> 输出top10和topn # if gvalue.tempLibLists.get(gvalue.mac_id) is None: # gvalue.tempLibList = [] # else: # gvalue.tempLibList = gvalue.tempLibLists[gvalue.mac_id] # ## 加载临时特征库 # tempMilvusName = "temp_features" # has = utility.has_collection(tempMilvusName) # print(f"Does collection {tempMilvusName} exist in Milvus: {has}") # tempMilvus = load_collection(tempMilvusName) # print(f"Number of entities in {tempMilvusName}: {tempMilvus.num_entities}") # if actionModel: # barcode_list = barcode_list # else: # barcode_list = queueImgs_add['barcode_list'] # if actionModel: # queBarIdList, queBarIdFeatures = img2feature(dict(list(queImgsDict.items())[2:-1]), model, actionModel, barcode_flag) # else: # queBarIdList, queBarIdFeatures = img2feature(dict(list(queImgsDict.items())[:-3]), model, actionModel, barcode_flag) # if barcode_flag: # if len(queBarIdList) == 0: # top10, top1, top_n = {}, {}, {} # else: # start1 = time.time() # top1 = searchImg.mainSearch1(main_milvus, queBarIdList, queBarIdFeatures) # start2 = time.time() # print('search top1 time>>>> {}'.format(start2 - start1)) # top10 = searchImg.mainSearch10(main_milvus, queBarIdList, queBarIdFeatures) # start3 = time.time() # print('search top10 time>>>> {}'.format(start3 - start2)) # top_n = searchImg.tempSearch(tempMilvus, queBarIdList, queBarIdFeatures, barcode_list, gvalue.tempLibList) # # print(f'top10: {top10}, top1: {top1}, topn: {top_n}') # return top10, top1, top_n # else: # top10 = searchImg.mainSearch10(main_milvus, queBarIdList, queBarIdFeatures) # top_n = searchImg.tempSearch(tempMilvus, queBarIdList, queBarIdFeatures, barcode_list, gvalue.tempLibList) # # print(f'top10: {top10}, topn: {top_n}') # return top10, top_n def similarity_interface(dataCollection): queImgsDict = dataCollection.queImgsDict add_flag = dataCollection.add_flag barcode_flag = dataCollection.barcode_flag main_milvus = dataCollection.mainMilvus #tempLibList = dataCollection.tempLibList model = dataCollection.model actionModel = dataCollection.actionModel barcode_list = dataCollection.barcode_list #return similarity(queImgsDict, add_flag, barcode_flag, main_milvus, tempLibList, model, barcode_list, actionModel) return 0 if __name__ == '__main__': pass # connections.connect('default', host=conf.host, port=conf.port) # # 加载主特征库 # mainMilvusName = "main_features" # has = utility.has_collection(mainMilvusName) # print(f"Does collection {mainMilvusName} exist in Milvus: {has}") # mainMilvus = Collection(mainMilvusName) # mainMilvus.load() # model = initModel() # # queueImgs_add queueImgs_back 分别为加购和退购时的入参 # add_flag = queueImgs_add['add_flag'] # barcode_flag = queueImgs_add['barcode_flag'] # tempLibList = [] # 临时特征库的barcodeId_list # # tempLibList = ['3500610085338_01', '4260290263776_01'] ##test # if add_flag: # if barcode_flag: # 加购 有barcode -> 输出top10和top1 # top10, top1, tempLibList = similarity(queueImgs_add, add_flag, barcode_flag, mainMilvus, tempLibList, model) # print(f"top10: {top10}\ntop1: {top1}") # else: # 加购 无barcode -> 输出top10 # top10, tempLibList = similarity(queueImgs_add, add_flag, barcode_flag, mainMilvus, tempLibList, model) # else: # 退购 -> 输出top10和topn # top10, topn = similarity(queueImgs_back, add_flag, barcode_flag, mainMilvus, tempLibList, model)