This commit is contained in:
lee
2025-06-11 15:23:50 +08:00
commit 37ecef40f7
79 changed files with 26981 additions and 0 deletions

233
tools/operate_usearch.py Normal file
View File

@ -0,0 +1,233 @@
import os
import numpy as np
from usearch.index import Index
import json
import struct
def create_index():
index = Index(
ndim=256,
metric='cos',
# dtype='f32',
dtype='f16',
connectivity=32,
expansion_add=40, # 128,
expansion_search=10, # 64,
multi=True
)
return index
def compare_feature(features1, features2, model='1'):
"""
:param model 比对策略
'0':模拟一个轨迹的图像(所有的图像、或者挑选的若干图像)与标准库,先求每个图片与标准库的最大值,再求所有图片对应最大值的均值
'1':带对比的所有相似度的均值
'2':比对1:1的最大值
:param feature1:
:param feature2:
:return:
"""
similarity_group, similarity_groups = [], []
if model == '0':
for feature1 in features1:
for feature2 in features2[0]:
similarity = np.dot(feature1, feature2) / (np.linalg.norm(feature1) * np.linalg.norm(feature2))
similarity_group.append(similarity)
similarity_groups.append(max(similarity_group))
similarity_group = []
return sum(similarity_groups) / len(similarity_groups)
elif model == '1':
feature2 = features2[0]
for feature1 in features1:
for num in range(len(feature2)):
similarity = np.dot(feature1, feature2[num]) / (
np.linalg.norm(feature1) * np.linalg.norm(feature2[num]))
similarity_group.append(similarity)
similarity_groups.append(sum(similarity_group) / len(similarity_group))
similarity_group = []
# return sum(similarity_groups)/len(similarity_groups), max(similarity_groups)
if len(similarity_groups) == 0:
return -1
return sum(similarity_groups) / len(similarity_groups)
elif model == '2':
feature2 = features2[0]
for feature1 in features1:
for num in range(len(feature2)):
similarity = np.dot(feature1, feature2[num]) / (
np.linalg.norm(feature1) * np.linalg.norm(feature2[num]))
similarity_group.append(similarity)
return max(similarity_group)
def get_barcode_feature(data):
barcode = data['key']
features = data['value']
return [barcode] * len(features), features
def analysis_file(file_path):
"""
:param file_path:
:return:
"""
barcodes, features = [], []
with open(file_path, 'r', encoding='utf-8') as f:
data = json.load(f)
for dic in data['total']:
barcode, feature = get_barcode_feature(dic)
barcodes.append(barcode)
features.append(feature)
return barcodes, features
def create_base_index(index_file_pth=None,
barcodes=None,
features=None,
save_index_name=None):
index = create_index()
if index_file_pth is not None:
# save_index_name = index_file_pth.split('json')[0] + 'usearch'
save_index_name = index_file_pth.split('json')[0] + 'data'
barcodes, features = analysis_file(index_file_pth)
else:
assert barcodes is not None and features is not None, 'barcodes and features must be not None'
for barcode, feature in zip(barcodes, features):
try:
index.add(np.array(barcode), np.array(feature))
except Exception as e:
print(e)
continue
index.save(save_index_name)
def get_feature_index(index_file_pth=None,
barcodes=None):
assert index_file_pth is not None, 'index_file_pth must be not None'
index = Index.restore(index_file_pth, view=True)
feature_lists = index.get(np.array(barcodes))
print("memory {} size {}".format(index.memory_usage, index.size))
print("feature_lists {}".format(feature_lists))
return feature_lists
def search_in_index(query=None,
barcode=None, # barcode -> int or np.ndarray
index_name=None,
temp_index=False, # 是否为临时库
model='0',
):
if temp_index:
assert index_name is not None, 'index_name must be not None'
index = Index.restore(index_name, view=True)
if barcode is not None: # 1:1对比测试
feature_lists = index.get(np.array(barcode))
results = compare_feature(query, feature_lists)
else:
results = index.search(query, count=5)
return results
else: # 标准库
assert index_name is not None, 'index_name must be not None'
index = Index.restore(index_name, view=True)
if barcode is not None: # 1:1对比测试
feature_lists = index.get(np.array(barcode))
results = compare_feature(query, feature_lists, model)
else:
results = index.search(query, count=10)
return results
def delete_index(index_name=None, key=None, index=None):
assert key is not None, 'key must be not None'
if index is None:
assert index_name is not None, 'index_name must be not None'
index = Index.restore(index_name, view=True)
index.remove(index_name)
else:
index.remove(key)
from scipy.spatial.distance import cdist
def compute_similarity_matrix(featurelists1, featurelists2):
"""计算图片之间的余弦相似度矩阵"""
# 计算所有向量对之间的余弦相似度
cosine_similarities = 1 - cdist(featurelists1, featurelists2, metric='cosine')
cosine_similarities = np.around(cosine_similarities, decimals=3)
return cosine_similarities
def check_usearch_json_diff(index_file_pth, json_file_pth):
json_features = None
feature_lists = get_feature_index(index_file_pth, ['6923644272159'])
with open(json_file_pth, 'r') as json_file:
json_data = json.load(json_file)
for data in json_data['total']:
if data['key'] == '6923644272159':
json_features = data['value']
json_features = np.array(json_features)
feature_lists = np.array(feature_lists[0])
compute_similarity_matrix(json_features, feature_lists)
def write_binary_file(filename, datas):
with open(filename, 'wb') as f:
# 先写入数据中的key数量为C++读取提供便利)
key_count = len(datas)
f.write(struct.pack('I', key_count)) # 'I'代表无符号整型4字节
for data in datas:
key = data['key']
feats = data['value']
key_bytes = key.encode('utf-8')
key_len = len(key)
length_byte = struct.pack('<B', key_len)
f.write(length_byte)
# f.write(struct.pack('Q', len(key_bytes)))
f.write(key_bytes)
value_count = len(feats)
f.write(struct.pack('I', (value_count * 256)))
# 遍历字典写入每个key及其对应的浮点数值列表
for values in feats:
# 写入每个浮点数值(保留小数点后六位)
for value in values:
# 使用'f'格式单精度浮点4字节并四舍五入保留六位小数
value_half = np.float16(value)
# print(value_half.tobytes())
f.write(value_half.tobytes())
def create_binary_file(json_path, flag=True):
# 1. 打开JSON文件
with open(json_path, 'r', encoding='utf-8') as file:
# 2. 读取并解析JSON文件内容
data = json.load(file)
if flag:
for flag, values in data.items():
# 逐个写入values中的每个值保留小数点后六位每个值占一行
write_binary_file(index_file_pth.replace('json', 'bin'), values)
else:
write_binary_file(json_path.replace('.json', '.bin'), [data])
def create_binary_files(index_file_pth):
if os.path.isfile(index_file_pth):
create_binary_file(index_file_pth)
else:
for name in os.listdir(index_file_pth):
jsonpth = os.sep.join([index_file_pth, name])
create_binary_file(jsonpth, False)
if __name__ == '__main__':
# index_file_pth = '../data/feature_json' # 生成二进制文件 多文件
index_file_pth = '../search_library/yunhedian_30-04.json'
# create_base_index(index_file_pth) # 生成usearch文件
create_binary_files(index_file_pth) # 生成二进制文件 多文件
# index_file_pth = '../search_library/test_index_10_normal_0717.usearch'
# # index_file_pth = '../search_library/data_10_normal_0718.index'
# search_in_index(query='693', index_name=index_file_pth, barcode='6934024590466')
# # check index data file
# index_file_pth = '../search_library/data_zhanting.data'
# # # get_feature_index(index_file_pth, ['6901070602818'])
# get_feature_index(index_file_pth, ['6923644272159'])
# index_file_pth = '../search_library/data_zhanting.data'
# json_file_pth = '../search_library/data_zhanting.json'
# check_usearch_json_diff(index_file_pth, json_file_pth)