154 lines
5.6 KiB
Python
154 lines
5.6 KiB
Python
import os
|
|
import numpy as np
|
|
from usearch.index import Index
|
|
import json
|
|
import statistics
|
|
|
|
|
|
def create_index(ndim=256,
                 metric='cos',
                 dtype='f16',
                 connectivity=32,
                 expansion_add=40,
                 expansion_search=10,
                 multi=True):
    """Create an empty usearch ``Index`` for barcode feature vectors.

    Every argument is forwarded unchanged to ``usearch.index.Index``. The
    defaults are the values this module has always used, so a bare
    ``create_index()`` builds exactly the same index as before.

    :param ndim: dimensionality of each stored feature vector.
    :param metric: similarity metric name passed to usearch ('cos' = cosine).
    :param dtype: scalar storage type ('f16' stores half-precision values;
        'f32' was used previously).
    :param connectivity: graph connectivity passed to usearch.
    :param expansion_add: search breadth while adding vectors
        (128 was tried previously).
    :param expansion_search: search breadth while querying
        (64 was tried previously).
    :param multi: allow several vectors under one key. The callers in this
        module add one key per vector, repeating the same barcode.
    :return: a new, empty ``Index``.
    """
    return Index(
        ndim=ndim,
        metric=metric,
        dtype=dtype,
        connectivity=connectivity,
        expansion_add=expansion_add,
        expansion_search=expansion_search,
        multi=multi,
    )
|
|
|
|
def _cosine_similarity_matrix(queries, gallery):
    """Return the (len(queries), len(gallery)) matrix of pairwise cosine similarities.

    Pairs involving a zero-norm vector get similarity 0 instead of NaN.
    """
    dots = queries @ gallery.T
    denom = (np.linalg.norm(queries, axis=1)[:, None]
             * np.linalg.norm(gallery, axis=1)[None, :])
    # `where` skips the division for zero denominators, leaving 0 from `out`.
    return np.divide(dots, denom, out=np.zeros_like(dots), where=denom != 0)


def compare_feature(features1, features2, model='1'):
    """Score how similar a set of query vectors is to a set of library vectors.

    :param features1: query feature vectors. Either a sequence of 1-D vectors
        or a single 1-D vector.
    :param features2: sequence whose first element holds the library vectors
        to compare against. Only ``features2[0]`` is used.
    :param model: comparison strategy:
        '0' - treat the queries as the images of one track (all of them or a
              selected subset). For each query, take its best similarity to
              the library, then average those maxima.
        '1' - mean of all pairwise cosine similarities (default).
        '2' - maximum over all pairwise (1:1) cosine similarities.
    :return: the score as a float, or -1 when either side has no vectors.
    :raises ValueError: if ``model`` is not '0', '1' or '2'.
    """
    if model not in ('0', '1', '2'):
        raise ValueError("unknown comparison model: {!r}".format(model))

    # Compute in float64. The index is built with dtype='f16', so vectors read
    # back from it may be half precision (TODO confirm what Index.get returns).
    queries = np.asarray(features1, dtype=np.float64)
    gallery = np.asarray(features2[0], dtype=np.float64)

    # -1 is the sentinel model '1' always used for "nothing to compare".
    # It now applies to every model instead of raising.
    if queries.size == 0 or gallery.size == 0:
        return -1

    sims = _cosine_similarity_matrix(np.atleast_2d(queries), np.atleast_2d(gallery))

    if model == '0':
        return float(sims.max(axis=1).mean())
    if model == '1':
        # Every row has the same length, so the mean of the per-query means
        # equals the mean over all pairs.
        return float(sims.mean())
    return float(sims.max())
|
|
|
|
|
|
|
|
def get_barcode_feature(data):
    """Split one library entry into a per-vector key list and its vectors.

    :param data: mapping with a 'key' (the barcode) and a 'value' (its vectors).
    :return: ``(keys, vectors)``, where ``keys`` repeats the barcode once for
        every element of ``vectors``.
    """
    key, vectors = data['key'], data['value']
    keys = [key for _ in range(len(vectors))]
    return keys, vectors
|
|
|
|
|
|
def analysis_file(file_path):
    """Parse a JSON feature-library file into parallel barcode and feature lists.

    The file is read as ``{"total": [{"key": ..., "value": [...]}, ...]}``, and
    each entry of "total" is unpacked with ``get_barcode_feature``.

    :param file_path: path of the UTF-8 encoded JSON file.
    :return: ``(barcodes, features)``. For each entry, ``barcodes`` holds its
        key repeated once per vector and ``features`` holds its vector list.
    """
    with open(file_path, 'r', encoding='utf-8') as handle:
        library = json.load(handle)
        pairs = [get_barcode_feature(entry) for entry in library['total']]
    barcodes = [keys for keys, _ in pairs]
    features = [vectors for _, vectors in pairs]
    return barcodes, features
|
|
|
|
|
|
def create_base_index(index_file_pth=None,
                      barcodes=None,
                      features=None,
                      save_index_name=None):
    """Build a usearch index from barcode feature vectors and save it.

    The (barcode, vectors) groups come either from a JSON library file or
    directly from ``barcodes``/``features``.

    :param index_file_pth: JSON library file readable by ``analysis_file``.
        When given, it takes precedence over ``barcodes``/``features``.
    :param barcodes: one key list per group (the barcode repeated once per
        vector), in the shape ``analysis_file`` returns.
    :param features: one vector list per group, parallel to ``barcodes``.
    :param save_index_name: output path. If omitted while ``index_file_pth``
        is given, defaults to the JSON path with its extension replaced by
        ``.data`` (e.g. ``data_0923.json`` -> ``data_0923.data``).
    """
    index = create_index()

    if index_file_pth is not None:
        if save_index_name is None:
            # Replace only the extension. The former split('json') cut the
            # path at the first "json" anywhere in it, e.g. a "json/" folder.
            save_index_name = os.path.splitext(index_file_pth)[0] + '.data'
        barcodes, features = analysis_file(index_file_pth)
    else:
        assert barcodes is not None and features is not None, 'barcodes and features must be not None'
        # NOTE(review): save_index_name is not checked on this path; confirm
        # what index.save(None) does before relying on it.

    # NOTE(review): usearch keys are expected to be integers; if the JSON keys
    # are strings, np.array(barcode) yields a string array - confirm.
    for barcode, feature in zip(barcodes, features):
        index.add(np.array(barcode), np.array(feature))

    index.save(save_index_name)
|
|
|
|
def get_feature_index(index_file_pth=None,
                      barcodes=None):
    """Read back the stored vectors for ``barcodes`` from a saved index file.

    The file is opened with ``Index.restore(..., view=True)``. As a quick
    sanity check, the index's memory usage and size are printed.

    :param index_file_pth: path of the saved index (required).
    :param barcodes: key(s) to look up; converted with ``np.array``.
    :return: whatever ``Index.get`` returns for those keys.
    """
    assert index_file_pth is not None, 'index_file_pth must be not None'
    loaded = Index.restore(index_file_pth, view=True)
    keys = np.array(barcodes)
    vectors = loaded.get(keys)
    print("memory {} size {}".format(loaded.memory_usage, loaded.size))
    return vectors
|
|
|
|
def search_in_index(query=None,
                    barcode=None,  # barcode -> int or np.ndarray
                    index_name=None,
                    temp_index=False,  # whether this is a temporary library
                    model='0',
                    ):
    """Query a saved index, either 1:1 against one barcode or by neighbour search.

    :param query: query vector(s). In 1:1 mode they become ``features1`` of
        ``compare_feature``; otherwise they are passed to ``Index.search``.
    :param barcode: key whose stored vectors ``query`` is compared against.
        When None, a nearest-neighbour search is run instead.
    :param index_name: path of the saved index to open (required).
    :param temp_index: True for a temporary library, False for the standard one.
    :param model: ``compare_feature`` strategy; only used for the standard
        library (see the note below).
    :return: the ``compare_feature`` score in 1:1 mode, otherwise the result of
        ``Index.search`` (top 5 for a temporary library, top 10 otherwise).
    """
    assert index_name is not None, 'index_name must be not None'
    index = Index.restore(index_name, view=True)

    if barcode is None:
        top_k = 5 if temp_index else 10
        return index.search(query, count=top_k)

    # 1:1 comparison test against the vectors stored under `barcode`.
    stored = index.get(np.array(barcode))
    if temp_index:
        # NOTE(review): the temporary-library path ignores `model` and uses
        # compare_feature's default strategy; confirm this is intended.
        return compare_feature(query, stored)
    return compare_feature(query, stored, model)
|
|
|
|
def delete_index(index_name=None, key=None, index=None):
    """Remove ``key`` from an already-loaded index or from an index file.

    :param index_name: path of a saved index. Used only when ``index`` is None:
        the file is loaded, ``key`` is removed, and the result is saved back to
        the same path.
    :param key: key (barcode) to remove; required.
    :param index: already-loaded index to modify in place. Nothing is saved in
        this case.
    """
    assert key is not None, 'key must be not None'

    if index is not None:
        index.remove(key)
        return

    assert index_name is not None, 'index_name must be not None'
    # Load a full copy rather than a view (view=True) so it can be modified.
    # Save it back afterwards; otherwise the removal is lost when this
    # function returns.
    index = Index.restore(index_name, view=False)
    # Previously this removed `index_name` (the file path) instead of `key`.
    index.remove(key)
    index.save(index_name)
|
|
|
|
if __name__ == '__main__':
    # Other entry points used during development:
    #   create_base_index('../search_library/data_0923.json')
    #       builds ../search_library/data_0923.data
    #   search_in_index(query=..., index_name=..., barcode=...)
    #       1:1 comparison against one barcode
    # Sanity-check a saved index by reading back the vectors of one barcode.
    saved_index_path = '../search_library/data_0923.data'
    get_feature_index(saved_index_path, ['6934230050105'])
|
|
|