import pdb


class ImgSearch:
    """Helpers for querying vector collections and merging per-barcode scores.

    The collection objects passed to the search methods only need a
    ``search(data, anns_field=..., param=..., limit=..., expr=...)`` method
    that returns an iterable of hit lists, where every hit exposes ``id`` and
    ``distance``.  NOTE(review): this looks like the pymilvus
    ``Collection.search`` API -- confirm against the callers.

    Scores are always merged by keeping the larger value, so a larger
    ``distance`` is treated as "more similar".
    """

    # Maximum number of entries kept by ``result_analysis`` and the search
    # limit used by ``mainSearch10``.
    TOP_K = 10

    def __init__(self):
        self.search_params = {
            "metric_type": "COSINE",
        }

    def get_max(self, a, b):
        """Return ``a`` when ``a > b``, otherwise ``b``."""
        # Kept as an explicit comparison (not ``max``) so ties and
        # non-comparable-looking values such as NaN resolve to ``b``.
        return a if a > b else b

    def check_keys(self, dict1_last, dict2):
        """Merge ``dict2`` into ``dict1_last`` in place, keeping the larger value.

        Args:
            dict1_last: Accumulated mapping; it is mutated and returned.
            dict2: Mapping whose entries are merged in.

        Returns:
            ``dict1_last`` after the merge.
        """
        for key, value in dict2.items():
            if key in dict1_last:
                dict1_last[key] = self.get_max(dict1_last[key], value)
            else:
                dict1_last[key] = value
        return dict1_last

    def result_analysis(self, result, top1_flag=False):
        """Collapse search hits into a ``{hit.id: best_score}`` mapping.

        Args:
            result: Iterable of hit lists; every hit has ``id`` and
                ``distance`` attributes.
            top1_flag: When true, return every collected id unsorted.  When
                false, return only the ``TOP_K`` highest scores, ordered from
                highest to lowest.

        Returns:
            Dict mapping hit id to its highest score, rounded to 2 decimals.
        """
        # Collect, per barcode (hit.id), the comparison results of all images.
        best_by_id = {}
        for hits in result:
            for hit in hits:
                if hit.id in best_by_id:
                    # Same barcode seen again: keep the higher similarity.
                    score = self.get_max(best_by_id[hit.id], hit.distance)
                else:
                    score = hit.distance
                best_by_id[hit.id] = round(score, 2)

        if top1_flag:
            return best_by_id

        # Sort all barcode similarities; slicing also covers short inputs.
        ranked = sorted(best_by_id.items(), key=lambda item: item[1], reverse=True)
        return dict(ranked[:self.TOP_K])

    def result_update(self, temp_result, last_result):
        """Merge ``temp_result`` into ``last_result`` in place.

        A key missing from ``last_result`` is added; a key present in both is
        overwritten only when the new value is strictly greater.

        Args:
            temp_result: Newly obtained ``{barcode: score}`` mapping.
            last_result: Accumulated mapping; it is mutated and returned.

        Returns:
            ``last_result`` after the merge.
        """
        for key, value in temp_result.items():
            # Shared barcode: update only if the new similarity is higher.
            # Unshared barcode: always add it.
            if key not in last_result or value > last_result[key]:
                last_result[key] = value
        return last_result

    def mainSearch10(self, mainMilvus, queBarIdList, queueFeatures):
        """Run a top-``TOP_K`` search for every queued id.

        Args:
            mainMilvus: Collection-like object providing ``search``.
            queBarIdList: Ids of the incoming boxes (barcode / track id),
                used as keys of the returned dict.
            queueFeatures: Query vectors; ``queueFeatures[i]`` belongs to
                ``queBarIdList[i]``.

        Returns:
            Dict mapping each queued id to its ``result_analysis`` output.

        Raises:
            IndexError: If ``queueFeatures`` is shorter than ``queBarIdList``.
        """
        result_last = {}
        for index, bar_id in enumerate(queBarIdList):
            result = mainMilvus.search(
                queueFeatures[index],
                anns_field='embeddings',
                param=self.search_params,
                limit=self.TOP_K,
            )
            result_last[bar_id] = self.result_analysis(result)
        return result_last

    def tempSearch(self, tempMilvus, queueList, queueFeatures, barIdList, tempbarId):
        """Search the temporary collection, restricted to known barcodes.

        Args:
            tempMilvus: Collection-like object providing ``search``.
            queueList: Ids used as keys of the returned dict.
            queueFeatures: Query vectors; ``queueFeatures[i]`` belongs to
                ``queueList[i]``.
            barIdList: Container of ``macID_barcode`` strings to keep.
            tempbarId: Candidate primary keys, expected in the form
                ``macID_barcode_trackId``; entries that do not split into
                exactly three ``_``-separated parts are ignored.

        Returns:
            ``{}`` when no candidate matches ``barIdList``; otherwise a dict
            mapping each id in ``queueList`` to its ``result_analysis`` output.

        Raises:
            IndexError: If ``queueFeatures`` is shorter than ``queueList``.
        """
        # Keep the candidates whose "macID_barcode" prefix is in barIdList.
        newBarList = []
        for bar in tempbarId:
            parts = bar.split('_')
            if len(parts) == 3 and parts[0] + '_' + parts[1] in barIdList:
                newBarList.append(bar)

        if len(newBarList) == 0:
            return {}

        # The filter expression embeds the Python list repr of the keys.
        expr = f"pk in {newBarList}"
        result_last = {}
        for index, queue_id in enumerate(queueList):
            result = tempMilvus.search(
                queueFeatures[index],
                anns_field='embeddings',
                expr=expr,
                param=self.search_params,
                limit=len(newBarList),
            )
            result_last[queue_id] = self.result_analysis(result)
        return result_last

    def mainSearch1(self, mainMilvus, queBarIdList, queFeatures):
        """Search each query against its own barcode and merge the scores.

        Args:
            mainMilvus: Collection-like object providing ``search``.
            queBarIdList: Ids of the incoming boxes in the form
                ``macID_barcode_trackId``; the second ``_``-separated field
                is used as the primary key to filter on.
            queFeatures: Query vectors; ``queFeatures[i]`` belongs to
                ``queBarIdList[i]``.

        Returns:
            Dict mapping barcode to its best score across all queries.  When
            no search produced a hit, the barcode of the first id is returned
            with score ``0``.  An empty ``queBarIdList`` yields ``{}``.

        Raises:
            IndexError: If an id contains no ``_`` or ``queFeatures`` is
                shorter than ``queBarIdList``.
        """
        # Nothing to search and no first id to fall back on.
        if len(queBarIdList) == 0:
            return {}

        result_last = {}
        for index, bar_id in enumerate(queBarIdList):
            pk_barcode = bar_id.split('_')[1]
            result = mainMilvus.search(
                queFeatures[index],
                anns_field='embeddings',
                expr=f"pk=='{pk_barcode}'",
                param=self.search_params,
                limit=1,
            )
            result_dic = self.result_analysis(result, top1_flag=True)
            # result_update already handles an empty operand on either side.
            result_last = self.result_update(result_dic, result_last)

        if len(result_last) == 0:
            pk_barcode = queBarIdList[0].split('_')[1]
            result_last[pk_barcode] = 0
        return result_last