Files
2024-11-27 15:37:10 +08:00

102 lines
4.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import pdb
class ImgSearch():
    """Vector-similarity lookups of query features against Milvus-style collections.

    Every search forwards ``self.search_params`` (COSINE metric). Scores are hit
    distances rounded to 2 decimals; throughout this class a larger score is
    treated as a better (more similar) match.
    """

    def __init__(self):
        # Forwarded as ``param=`` to every ``search`` call below.
        self.search_params = {
            "metric_type": "COSINE",
        }

    def get_max(self, a, b):
        """Return the larger of ``a`` and ``b`` (``b`` when they compare equal)."""
        return a if a > b else b

    def check_keys(self, dict1_last, dict2):
        """Merge ``dict2`` into ``dict1_last``, keeping the larger value per key.

        ``dict1_last`` is modified in place and also returned.
        """
        for key in list(dict2.keys()):
            if key in dict1_last:
                dict1_last[key] = self.get_max(dict1_last[key], dict2[key])
            else:
                dict1_last[key] = dict2[key]
        return dict1_last

    def result_analysis(self, result, top1_flag=False, top_k=10):
        """Collapse raw search hits into ``{hit.id: best score}``.

        Args:
            result: Iterable of hit lists as returned by ``search``; each hit must
                expose ``.id`` and ``.distance``.
            top1_flag: When true, return every id, unsorted and untruncated.
            top_k: Maximum number of entries kept when ``top1_flag`` is false
                (default 10, the previously hard-coded limit; ``None`` keeps all).

        Returns:
            Dict mapping each hit id to its highest distance rounded to 2
            decimals. Unless ``top1_flag`` is set, entries are ordered by
            descending score and truncated to ``top_k``.
        """
        # The same id (e.g. one barcode matched by several images) can appear in
        # several hit lists; keep only its best score.
        result_dict = dict()
        for hits in result:
            for hit in hits:
                if hit.id not in result_dict:
                    result_dict[hit.id] = round(hit.distance, 2)
                else:
                    best = self.get_max(result_dict[hit.id], hit.distance)
                    result_dict[hit.id] = round(best, 2)
        if top1_flag:
            return result_dict
        # Slicing already copes with fewer than top_k entries.
        ranked = sorted(result_dict.items(), key=lambda item: item[1], reverse=True)
        return dict(ranked[:top_k])

    def result_update(self, temp_result, last_result):
        """Merge ``temp_result`` into ``last_result``, keeping the higher score.

        A key present in both is overwritten only when ``temp_result``'s value is
        strictly greater; keys only in ``temp_result`` are added. ``last_result``
        is modified in place and returned.
        """
        for key, temp_value in list(temp_result.items()):
            if key not in last_result or temp_value > last_result[key]:
                last_result[key] = temp_value
        return last_result

    def mainSearch10(self, mainMilvus, queBarIdList, queueFeatures):
        """Run an unfiltered top-10 search for every query against ``mainMilvus``.

        Args:
            mainMilvus: Collection-like object exposing ``search``.
            queBarIdList: Query identifiers (per the original note, the box's
                barcode-track_Id); used as keys of the returned dict.
            queueFeatures: ``queueFeatures[i]`` is passed as the search data for
                ``queBarIdList[i]`` (presumably a list of vectors — confirm).

        Returns:
            ``{query_id: result_analysis(hits)}`` for every query.
        """
        result_last = dict()
        for i, query_id in enumerate(queBarIdList):
            # Index (not zip) so a too-short feature list still raises IndexError.
            hits = mainMilvus.search(queueFeatures[i], anns_field='embeddings',
                                     param=self.search_params, limit=10)
            result_last[query_id] = self.result_analysis(hits)
        return result_last

    def tempSearch(self, tempMilvus, queueList, queueFeatures, barIdList, tempbarId):
        """Search each query against ``tempMilvus``, restricted to known barcodes.

        Args:
            tempMilvus: Collection-like object exposing ``search``.
            queueList: Query identifiers; keys of the returned dict.
            queueFeatures: ``queueFeatures[i]`` is the search data for ``queueList[i]``.
            barIdList: Container of ``macID_barcode`` strings to keep.
            tempbarId: Primary keys of the form ``macID_barcode_trackId``.

        Returns:
            ``{}`` when no entry of ``tempbarId`` matches ``barIdList``; otherwise
            ``{query_id: result_analysis(hits)}`` for every query.
        """
        # Keep only well-formed keys (exactly three '_'-separated parts) whose
        # macID_barcode prefix is listed in barIdList.
        newBarList = []
        for bar in tempbarId:
            parts = bar.split('_')
            if len(parts) == 3 and (parts[0] + '_' + parts[1]) in barIdList:
                newBarList.append(bar)
        if not newBarList:
            return {}
        # The list's repr quotes/escapes every key, e.g. "pk in ['a_b_1']".
        expr = f"pk in {newBarList}"
        result_last = dict()
        for i, query_id in enumerate(queueList):
            hits = tempMilvus.search(queueFeatures[i], anns_field='embeddings', expr=expr,
                                     param=self.search_params, limit=len(newBarList))
            result_last[query_id] = self.result_analysis(hits)
        return result_last

    def mainSearch1(self, mainMilvus, queBarIdList, queFeatures):
        """Best score per barcode from a pk-filtered top-1 search of each query.

        Each query id has the form ``macID_barcode_trackId``; its barcode is used
        as the filter ``pk == barcode`` with ``limit=1``.

        Returns:
            ``{hit_id: best score}`` merged across all queries (higher kept).
            When no query produced a hit, ``{first_query_barcode: 0}``.
            An empty ``queBarIdList`` yields ``{}``.
        """
        result_last = dict()
        for i, query_id in enumerate(queBarIdList):
            pk_barcode = query_id.split('_')[1]
            # !r quotes/escapes the barcode (same as tempSearch's list repr), so a
            # quote inside it cannot break the filter; plain barcodes give pk=='x'.
            hits = mainMilvus.search(queFeatures[i], anns_field='embeddings',
                                     expr=f"pk=={pk_barcode!r}",
                                     param=self.search_params, limit=1)
            # result_update adds new keys and keeps the higher value for shared
            # ones, which also covers the first query and empty results.
            result_last = self.result_update(self.result_analysis(hits, top1_flag=True),
                                             result_last)
        if not result_last and len(queBarIdList) > 0:
            # Nothing matched: report the first query's barcode with score 0.
            result_last[queBarIdList[0].split('_')[1]] = 0
        return result_last