update
This commit is contained in:
153
tools/operate_usearch.py
Normal file
153
tools/operate_usearch.py
Normal file
@ -0,0 +1,153 @@
|
||||
import os
|
||||
import numpy as np
|
||||
from usearch.index import Index
|
||||
import json
|
||||
import statistics
|
||||
|
||||
|
||||
def create_index():
|
||||
index = Index(
|
||||
ndim=256,
|
||||
metric='cos',
|
||||
# dtype='f32',
|
||||
dtype='f16',
|
||||
connectivity=32,
|
||||
expansion_add=40,#128,
|
||||
expansion_search=10,#64,
|
||||
multi=True
|
||||
)
|
||||
return index
|
||||
|
||||
def compare_feature(features1, features2, model = '1'):
|
||||
"""
|
||||
:param model 比对策略
|
||||
'0':模拟一个轨迹的图像(所有的图像、或者挑选的若干图像)与标准库,先求每个图片与标准库的最大值,再求所有图片对应最大值的均值
|
||||
'1':带对比的所有相似度的均值
|
||||
'2':比对1:1的最大值
|
||||
:param feature1:
|
||||
:param feature2:
|
||||
:return:
|
||||
"""
|
||||
similarity_group, similarity_groups = [], []
|
||||
if model == '0':
|
||||
for feature1 in features1:
|
||||
for feature2 in features2[0]:
|
||||
similarity = np.dot(feature1, feature2) / (np.linalg.norm(feature1) * np.linalg.norm(feature2))
|
||||
similarity_group.append(similarity)
|
||||
similarity_groups.append(max(similarity_group))
|
||||
similarity_group = []
|
||||
return sum(similarity_groups)/len(similarity_groups)
|
||||
|
||||
elif model == '1':
|
||||
feature2 = features2[0]
|
||||
for feature1 in features1:
|
||||
for num in range(len(feature2)):
|
||||
similarity = np.dot(feature1, feature2[num]) / (np.linalg.norm(feature1) * np.linalg.norm(feature2[num]))
|
||||
similarity_group.append(similarity)
|
||||
similarity_groups.append(sum(similarity_group) / len(similarity_group))
|
||||
similarity_group = []
|
||||
# return sum(similarity_groups)/len(similarity_groups), max(similarity_groups)
|
||||
if len(similarity_groups) == 0:
|
||||
return -1
|
||||
return sum(similarity_groups)/len(similarity_groups)
|
||||
elif model == '2':
|
||||
feature2 = features2[0]
|
||||
for feature1 in features1:
|
||||
for num in range(len(feature2)):
|
||||
similarity = np.dot(feature1, feature2[num]) / (np.linalg.norm(feature1) * np.linalg.norm(feature2[num]))
|
||||
similarity_group.append(similarity)
|
||||
return max(similarity_group)
|
||||
|
||||
|
||||
|
||||
def get_barcode_feature(data):
|
||||
barcode = data['key']
|
||||
features = data['value']
|
||||
return [barcode] * len(features), features
|
||||
|
||||
|
||||
def analysis_file(file_path):
|
||||
"""
|
||||
:param file_path:
|
||||
:return:
|
||||
"""
|
||||
barcodes, features = [], []
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
for dic in data['total']:
|
||||
barcode, feature = get_barcode_feature(dic)
|
||||
barcodes.append(barcode)
|
||||
features.append(feature)
|
||||
return barcodes, features
|
||||
|
||||
|
||||
def create_base_index(index_file_pth=None,
|
||||
barcodes=None,
|
||||
features=None,
|
||||
save_index_name=None):
|
||||
index = create_index()
|
||||
if index_file_pth is not None:
|
||||
# save_index_name = index_file_pth.split('json')[0] + 'usearch'
|
||||
save_index_name = index_file_pth.split('json')[0] + 'data'
|
||||
barcodes, features = analysis_file(index_file_pth)
|
||||
else:
|
||||
assert barcodes is not None and features is not None, 'barcodes and features must be not None'
|
||||
for barcode, feature in zip(barcodes, features):
|
||||
index.add(np.array(barcode), np.array(feature))
|
||||
index.save(save_index_name)
|
||||
|
||||
def get_feature_index(index_file_pth=None,
|
||||
barcodes=None):
|
||||
assert index_file_pth is not None, 'index_file_pth must be not None'
|
||||
index = Index.restore(index_file_pth, view=True)
|
||||
feature_lists = index.get(np.array(barcodes))
|
||||
print("memory {} size {}".format(index.memory_usage, index.size))
|
||||
return feature_lists
|
||||
|
||||
def search_in_index(query=None,
|
||||
barcode=None, # barcode -> int or np.ndarray
|
||||
index_name=None,
|
||||
temp_index=False, # 是否为临时库
|
||||
model='0',
|
||||
):
|
||||
if temp_index:
|
||||
assert index_name is not None, 'index_name must be not None'
|
||||
index = Index.restore(index_name, view=True)
|
||||
if barcode is not None: # 1:1对比测试
|
||||
feature_lists = index.get(np.array(barcode))
|
||||
results = compare_feature(query, feature_lists)
|
||||
else:
|
||||
results = index.search(query, count=5)
|
||||
return results
|
||||
else: # 标准库
|
||||
assert index_name is not None, 'index_name must be not None'
|
||||
index = Index.restore(index_name, view=True)
|
||||
if barcode is not None: # 1:1对比测试
|
||||
feature_lists = index.get(np.array(barcode))
|
||||
results = compare_feature(query, feature_lists, model)
|
||||
else:
|
||||
results = index.search(query, count=10)
|
||||
return results
|
||||
|
||||
def delete_index(index_name=None, key=None, index=None):
|
||||
assert key is not None, 'key must be not None'
|
||||
if index is None:
|
||||
assert index_name is not None, 'index_name must be not None'
|
||||
index = Index.restore(index_name, view=True)
|
||||
index.remove(index_name)
|
||||
else:
|
||||
index.remove(key)
|
||||
|
||||
if __name__ == '__main__':
|
||||
# index_file_pth = '../search_library/data_0923.json'
|
||||
# create_base_index(index_file_pth)
|
||||
|
||||
# index_file_pth = '../search_library/test_index_10_normal_0717.usearch'
|
||||
# # index_file_pth = '../search_library/data_10_normal_0718.index'
|
||||
# search_in_index(query='693', index_name=index_file_pth, barcode='6934024590466')
|
||||
|
||||
# check index data file
|
||||
index_file_pth = '../search_library/data_0923.data'
|
||||
# # get_feature_index(index_file_pth, ['6901070602818'])
|
||||
get_feature_index(index_file_pth, ['6934230050105'])
|
||||
|
Reference in New Issue
Block a user