This commit is contained in:
lee
2024-11-27 15:37:10 +08:00
commit 3a5214c796
696 changed files with 56947 additions and 0 deletions

101
contrast/search.py Normal file
View File

@ -0,0 +1,101 @@
import pdb
class ImgSearch():
def __init__(self):
self.search_params = {
"metric_type": "COSINE",
}
def get_max(self, a, b):
if a > b:
return a
else:
return b
def check_keys(self, dict1_last, dict2):
for key2 in list(dict2.keys()):
if key2 in list(dict1_last.keys()):
value = self.get_max(dict1_last[key2], dict2[key2])
dict1_last[key2] = value
else:
dict1_last[key2] = dict2[key2]
return dict1_last
def result_analysis(self, result, top1_flag=False):
result_dict = dict() ## 将同一barcode所有图片比对结果保存到该字典
for hits in result:
for hit in hits:
if not hit.id in result_dict: ## barcodehit.id不在结果字典中
result_dict.update({hit.id: round(hit.distance, 2)})
else: ## 将同一barcode相似度保存较高的
distance = result_dict.get(hit.id)
distance_new = self.get_max(distance, hit.distance)
result_dict.update({hit.id: round(distance_new, 2)})
if top1_flag:
return result_dict
else:
## 将所有barcode相似度结果排序存储
if len(result_dict) > 10:
result_sort_dict = dict(sorted(result_dict.items(), key=lambda x: x[1], reverse=True)[:10])
else:
result_sort_dict = dict(sorted(result_dict.items(), key=lambda x: x[1], reverse=True))
return result_sort_dict
def result_update(self, temp_result, last_result):
temp_keys = list(temp_result.keys())
last_keys = list(last_result.keys())
for ke in temp_keys:
temp_value = temp_result[ke]
if ke in last_keys: ## track_id1的结果和track_id2的结果有公共barcodetrack_id2中barcode相似度高才更新
last_value = last_result[ke]
if temp_value > last_value:
last_result.update({ke: temp_value})
else: ## track_id1的结果和track_id2的结果无公共barcode
last_result.update({ke: temp_value})
return last_result
def mainSearch10(self, mainMilvus, queBarIdList, queueFeatures): ###queueBarIdList->传入的box barcode-track_Id
result_last = dict()
for i in range(len(queBarIdList)):
vectorsSearch = queueFeatures[i]
result = mainMilvus.search(vectorsSearch, anns_field='embeddings', param=self.search_params, limit=10)
result_sort_dic = self.result_analysis(result)
result_last.update({queBarIdList[i]: result_sort_dic})
return result_last
def tempSearch(self, tempMilvus, queueList, queueFeatures, barIdList, tempbarId):
newBarList = []
### tempbarId格式->[macID_barcode_trackId1,..., macID_barcode_trackIdn]
for bar in tempbarId: ### 找出barIdList和tempbarId中共有的barcode
if len(bar.split('_')) == 3:
mac_barcode = bar.split('_')[0] + '_' + bar.split('_')[1]
if mac_barcode in barIdList:
newBarList.append(bar) ## newBarList格式->[macID_barcode_trackId1,..., macID_barcode_trackIdm]
if len(newBarList) == 0:
return {}
else:
expr = f"pk in {newBarList}"
result_last = dict()
for i in range(len(queueList)):
vectorsSearch = queueFeatures[i]
result = tempMilvus.search(vectorsSearch, anns_field='embeddings', expr=expr, param=self.search_params,
limit=len(newBarList))
result_sort_dic = self.result_analysis(result)
result_last.update({queueList[i]: result_sort_dic})
return result_last
def mainSearch1(self, mainMilvus, queBarIdList, queFeatures): ###queueBarIdList->传入的box macID_barcode_trackId
result_last = dict()
for i in range(len(queBarIdList)):
pk_barcode = queBarIdList[i].split('_')[1] #### 解析barcode 查询图片名称为macID_barcode_trackId
vectorsSearch = queFeatures[i]
result = mainMilvus.search(vectorsSearch, anns_field='embeddings', expr=f"pk=='{pk_barcode}'",
param=self.search_params, limit=1)
result_dic = self.result_analysis(result, top1_flag=True)
if (len(result_dic) != 0) and (len(result_last) != 0):
result_last = self.result_update(result_dic, result_last)
else:
result_last.update({key: value for key, value in result_dic.items()})
if len(result_last) == 0:
pk_barcode = queBarIdList[0].split('_')[1]
result_last.update({pk_barcode: 0})
return result_last