102 lines
4.8 KiB
Python
102 lines
4.8 KiB
Python
import pdb
|
||
|
||
|
||
class ImgSearch():
    """Run vector similarity searches and merge the per-barcode scores.

    Every search method takes a collection-like object exposing
    ``search(data, anns_field=..., param=..., limit=..., expr=...)``
    (presumably a pymilvus ``Collection`` -- confirm against the caller).
    The search result is iterated as groups of hits, where each hit exposes
    ``.id`` and ``.distance``.
    """

    def __init__(self):
        # Search parameters forwarded unchanged to every ``search`` call.
        self.search_params = {
            "metric_type": "COSINE",
        }

    def get_max(self, a, b):
        """Return ``a`` when ``a > b``, otherwise ``b``."""
        return a if a > b else b

    def check_keys(self, dict1_last, dict2):
        """Merge ``dict2`` into ``dict1_last`` in place, keeping the larger value.

        Keys only present in ``dict2`` are copied over; for shared keys the
        larger of the two values wins.

        Returns:
            ``dict1_last`` (the same object that was passed in).
        """
        for key, value in dict2.items():
            if key in dict1_last:
                dict1_last[key] = self.get_max(dict1_last[key], value)
            else:
                dict1_last[key] = value
        return dict1_last

    def result_analysis(self, result, top1_flag=False):
        """Collapse a search result into ``{hit.id: best distance}``.

        Args:
            result: iterable of hit groups; each hit exposes ``.id`` and
                ``.distance``.
            top1_flag: when true, return every collected entry unsorted;
                otherwise return at most the 10 highest-scoring entries,
                ordered from highest to lowest distance.

        Returns:
            dict mapping hit id to its highest distance, rounded to 2 decimals.
        """
        # Best score seen so far for each barcode (hit.id) across all hit groups.
        best_by_id = {}
        for hits in result:
            for hit in hits:
                if hit.id not in best_by_id:
                    # First time this barcode shows up.
                    best_by_id[hit.id] = round(hit.distance, 2)
                else:
                    # Same barcode seen again: keep the higher similarity.
                    best = self.get_max(best_by_id[hit.id], hit.distance)
                    best_by_id[hit.id] = round(best, 2)

        if top1_flag:
            return best_by_id

        # Rank all barcodes by similarity; slicing is a no-op when there are
        # fewer than 10 entries, so no separate length check is needed.
        ranked = sorted(best_by_id.items(), key=lambda item: item[1], reverse=True)
        return dict(ranked[:10])

    def result_update(self, temp_result, last_result):
        """Fold ``temp_result`` into ``last_result`` in place.

        A key missing from ``last_result`` is added; a key present in both is
        overwritten only when the value in ``temp_result`` is strictly larger.

        Returns:
            ``last_result`` (the same object that was passed in).
        """
        for key, temp_value in temp_result.items():
            if key not in last_result or temp_value > last_result[key]:
                last_result[key] = temp_value
        return last_result

    def mainSearch10(self, mainMilvus, queBarIdList, queueFeatures):
        """Search the main collection once per query and keep the top 10 each.

        Args:
            mainMilvus: collection-like object with a ``search`` method.
            queBarIdList: identifiers of the incoming boxes (per the original
                note: barcode / track id), used as keys of the returned dict.
            queueFeatures: query vectors, indexed in step with
                ``queBarIdList``.

        Returns:
            dict mapping each entry of ``queBarIdList`` to the dict produced
            by ``result_analysis`` for its search.
        """
        result_last = {}
        for i, bar_id in enumerate(queBarIdList):
            result = mainMilvus.search(queueFeatures[i], anns_field='embeddings',
                                       param=self.search_params, limit=10)
            result_last[bar_id] = self.result_analysis(result)
        return result_last

    def tempSearch(self, tempMilvus, queueList, queueFeatures, barIdList, tempbarId):
        """Search the temporary collection, restricted to known barcodes.

        Args:
            tempMilvus: collection-like object with a ``search`` method.
            queueList: identifiers used as keys of the returned dict.
            queueFeatures: query vectors, indexed in step with ``queueList``.
            barIdList: container of ``macID_barcode`` strings to keep.
            tempbarId: candidate primary keys, expected in the form
                ``macID_barcode_trackId``.

        Returns:
            ``{}`` when no entry of ``tempbarId`` matches ``barIdList``;
            otherwise a dict mapping each entry of ``queueList`` to the dict
            produced by ``result_analysis`` for its search.
        """
        # Keep the well-formed keys whose "macID_barcode" prefix is in barIdList.
        newBarList = []
        for bar in tempbarId:
            parts = bar.split('_')
            if len(parts) == 3 and parts[0] + '_' + parts[1] in barIdList:
                newBarList.append(bar)

        if not newBarList:
            return {}

        # Restrict the search to the matching primary keys.
        expr = f"pk in {newBarList}"
        result_last = {}
        for i, queue_id in enumerate(queueList):
            result = tempMilvus.search(queueFeatures[i], anns_field='embeddings', expr=expr,
                                       param=self.search_params, limit=len(newBarList))
            result_last[queue_id] = self.result_analysis(result)
        return result_last

    def mainSearch1(self, mainMilvus, queBarIdList, queFeatures):
        """Search the main collection for each query's own barcode (top 1).

        Args:
            mainMilvus: collection-like object with a ``search`` method.
            queBarIdList: query names in the form ``macID_barcode_trackId``;
                the barcode is the second ``_``-separated field.
            queFeatures: query vectors, indexed in step with ``queBarIdList``.

        Returns:
            dict mapping hit id to the highest distance found over all
            queries. When no search returns a hit, the dict is
            ``{barcode of the first query: 0}``. An empty ``queBarIdList``
            yields ``{}``.
        """
        result_last = {}
        # Nothing to search, and no first barcode to fall back on.
        if len(queBarIdList) == 0:
            return result_last

        for i, bar_id in enumerate(queBarIdList):
            # The query name is macID_barcode_trackId; extract the barcode.
            pk_barcode = bar_id.split('_')[1]
            result = mainMilvus.search(queFeatures[i], anns_field='embeddings',
                                       expr=f"pk=='{pk_barcode}'",
                                       param=self.search_params, limit=1)
            result_dic = self.result_analysis(result, top1_flag=True)
            # result_update also covers the cases where either dict is empty.
            result_last = self.result_update(result_dic, result_last)

        if not result_last:
            # No hit at all: report the first query's barcode with score 0.
            pk_barcode = queBarIdList[0].split('_')[1]
            result_last[pk_barcode] = 0
        return result_last
|
||
|