ieemoo-ai-contrast/tools/operate_usearch.py

import os
import json
import struct

import numpy as np
from scipy.spatial.distance import cdist
from usearch.index import Index


def create_index():
    index = Index(
        ndim=256,
        metric='cos',
        # dtype='f32',
        dtype='f16',
        connectivity=32,
        expansion_add=40,     # 128
        expansion_search=10,  # 64
        multi=True
    )
    return index
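
# Usage sketch (hypothetical keys and vectors, not data shipped with this repo):
# the index stores 256-d float16 vectors under integer keys, and multi=True
# lets several vectors share one key, i.e. several images of one barcode.
#
#   index = create_index()
#   index.add(np.full(3, 42, dtype=np.int64),                  # same key, 3 images
#             np.random.rand(3, 256).astype(np.float32))
#   matches = index.search(np.random.rand(256).astype(np.float32), count=5)
#   print(matches.keys, matches.distances)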


def compare_feature(features1, features2, model='1'):
    """
    :param features1: query features, e.g. all (or selected) images of one track
    :param features2: gallery features pulled from the index; the vectors sit in features2[0]
    :param model: comparison strategy
        '0': for each query image take its maximum similarity against the gallery,
             then average those maxima over all query images
        '1': average of all pairwise similarities
        '2': maximum over all 1:1 pairwise similarities
    :return: similarity score, or -1 when there is nothing to compare
    """
    similarity_group, similarity_groups = [], []
    if model == '0':
        for feature1 in features1:
            for feature2 in features2[0]:
                similarity = np.dot(feature1, feature2) / (np.linalg.norm(feature1) * np.linalg.norm(feature2))
                similarity_group.append(similarity)
            similarity_groups.append(max(similarity_group))
            similarity_group = []
        return sum(similarity_groups) / len(similarity_groups)
    elif model == '1':
        feature2 = features2[0]
        for feature1 in features1:
            for num in range(len(feature2)):
                similarity = np.dot(feature1, feature2[num]) / (
                        np.linalg.norm(feature1) * np.linalg.norm(feature2[num]))
                similarity_group.append(similarity)
            similarity_groups.append(sum(similarity_group) / len(similarity_group))
            similarity_group = []
        # return sum(similarity_groups) / len(similarity_groups), max(similarity_groups)
        if len(similarity_groups) == 0:
            return -1
        return sum(similarity_groups) / len(similarity_groups)
    elif model == '2':
        feature2 = features2[0]
        for feature1 in features1:
            for num in range(len(feature2)):
                similarity = np.dot(feature1, feature2[num]) / (
                        np.linalg.norm(feature1) * np.linalg.norm(feature2[num]))
                similarity_group.append(similarity)
        return max(similarity_group)
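
# Minimal sketch of the three strategies on random vectors (illustrative only):
#
#   q = np.random.rand(4, 256)        # 4 query images of one track
#   g = np.random.rand(1, 6, 256)     # 6 gallery vectors, wrapped so g[0] holds them
#   compare_feature(q, g, model='0')  # mean over per-image maxima
#   compare_feature(q, g, model='1')  # mean over all 4 * 6 pairs
#   compare_feature(q, g, model='2')  # single best 1:1 pair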


def get_barcode_feature(data):
    barcode = data['key']
    features = data['value']
    # repeat the barcode so keys and feature vectors stay aligned one-to-one
    return [barcode] * len(features), features


def analysis_file(file_path):
    """
    Parse a feature-library json file into parallel barcode / feature lists.
    :param file_path: json file laid out as {'total': [{'key': barcode, 'value': [vector, ...]}, ...]}
    :return: (barcodes, features), one entry per barcode
    """
    barcodes, features = [], []
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
        for dic in data['total']:
            barcode, feature = get_barcode_feature(dic)
            barcodes.append(barcode)
            features.append(feature)
    return barcodes, features
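
# Expected json layout, inferred from get_barcode_feature (vectors shortened):
#
#   {"total": [{"key": "6923644272159", "value": [[0.01, ...], [0.02, ...]]}]}
#
# Each entry in "value" is one 256-d feature vector for that barcode.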


def create_base_index(index_file_pth=None,
                      barcodes=None,
                      features=None,
                      save_index_name=None):
    index = create_index()
    if index_file_pth is not None:
        # save_index_name = index_file_pth.split('json')[0] + 'usearch'
        save_index_name = index_file_pth.split('json')[0] + 'data'
        barcodes, features = analysis_file(index_file_pth)
    else:
        assert barcodes is not None and features is not None, 'barcodes and features must not be None'
    for barcode, feature in zip(barcodes, features):
        try:
            index.add(np.array(barcode), np.array(feature))
        except Exception as e:
            print(e)
            continue
    index.save(save_index_name)
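
# Two equivalent ways to build a library; paths and keys below are examples,
# not data shipped with the repo:
#
#   create_base_index('../search_library/demo.json')   # parse json, save demo.data
#   create_base_index(barcodes=[[123, 123]],            # or pass parsed lists directly
#                     features=[[[0.1] * 256, [0.2] * 256]],
#                     save_index_name='demo.data')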


def get_feature_index(index_file_pth=None,
                      barcodes=None):
    assert index_file_pth is not None, 'index_file_pth must not be None'
    index = Index.restore(index_file_pth, view=True)
    feature_lists = index.get(np.array(barcodes))
    print("memory {} size {}".format(index.memory_usage, index.size))
    print("feature_lists {}".format(feature_lists))
    return feature_lists


def search_in_index(query=None,
                    barcode=None,      # barcode -> int or np.ndarray
                    index_name=None,
                    temp_index=False,  # whether this is a temporary library
                    model='0',
                    ):
    if temp_index:
        assert index_name is not None, 'index_name must not be None'
        index = Index.restore(index_name, view=True)
        if barcode is not None:  # 1:1 comparison test
            feature_lists = index.get(np.array(barcode))
            results = compare_feature(query, feature_lists)
        else:
            results = index.search(query, count=5)
        return results
    else:  # standard library
        assert index_name is not None, 'index_name must not be None'
        index = Index.restore(index_name, view=True)
        if barcode is not None:  # 1:1 comparison test
            feature_lists = index.get(np.array(barcode))
            results = compare_feature(query, feature_lists, model)
        else:
            results = index.search(query, count=10)
        return results
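
# Usage sketch (hypothetical paths and barcode): with barcode=None this is a
# 1:N ANN search; with a barcode it becomes a 1:1 check via compare_feature.
#
#   hits = search_in_index(query=np.random.rand(256).astype(np.float32),
#                          index_name='demo.data')
#   score = search_in_index(query=np.random.rand(2, 256),  # 2 query images
#                           barcode='6923644272159',
#                           index_name='demo.data')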


def delete_index(index_name=None, key=None, index=None):
    assert key is not None, 'key must not be None'
    if index is None:
        assert index_name is not None, 'index_name must not be None'
        index = Index.restore(index_name, view=True)
        index.remove(key)  # was index.remove(index_name), which removed by file path instead of key
    else:
        index.remove(key)
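
# Note: Index.restore(..., view=True) memory-maps the file, and a removal on the
# restored index is not written back to disk unless index.save(...) is called
# afterwards (a read-only view may reject modification outright).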


def compute_similarity_matrix(featurelists1, featurelists2):
    """Compute the cosine-similarity matrix between two sets of image features."""
    # cosine similarity between every pair of vectors
    cosine_similarities = 1 - cdist(featurelists1, featurelists2, metric='cosine')
    cosine_similarities = np.around(cosine_similarities, decimals=3)
    return cosine_similarities
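
# Example: a (2, 3) similarity matrix between two small random feature sets.
#
#   a = np.random.rand(2, 256)
#   b = np.random.rand(3, 256)
#   print(compute_similarity_matrix(a, b))  # values rounded to 3 decimals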


def check_usearch_json_diff(index_file_pth, json_file_pth):
    json_features = None
    feature_lists = get_feature_index(index_file_pth, ['6923644272159'])
    with open(json_file_pth, 'r') as json_file:
        json_data = json.load(json_file)
        for data in json_data['total']:
            if data['key'] == '6923644272159':
                json_features = data['value']
    json_features = np.array(json_features)
    feature_lists = np.array(feature_lists[0])
    # print the matrix so the diff is actually visible (the result was discarded before)
    print(compute_similarity_matrix(json_features, feature_lists))


def write_binary_file(filename, datas):
    with open(filename, 'wb') as f:
        # write the number of keys first, for the convenience of the C++ reader
        key_count = len(datas)
        f.write(struct.pack('I', key_count))  # 'I' = unsigned int, 4 bytes
        for data in datas:
            key = data['key']
            feats = data['value']
            key_bytes = key.encode('utf-8')
            key_len = len(key_bytes)  # length in bytes (was len(key), i.e. characters)
            length_byte = struct.pack('<B', key_len)
            f.write(length_byte)
            # f.write(struct.pack('Q', len(key_bytes)))
            f.write(key_bytes)
            value_count = len(feats)
            f.write(struct.pack('I', (value_count * 256)))  # total float count: vectors * 256 dims
            # write every feature vector belonging to this key
            for values in feats:
                for value in values:
                    # store as float16 (2 bytes per value) to halve the file size
                    value_half = np.float16(value)
                    # print(value_half.tobytes())
                    f.write(value_half.tobytes())
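
# Resulting file layout, matching the writes above (native byte order except for
# key_len, which is packed with an explicit '<'; little-endian on typical x86):
#
#   uint32  key_count
#   repeated key_count times:
#       uint8    key_len
#       bytes    key (utf-8, key_len bytes)
#       uint32   float_count              # n_vectors * 256
#       float16  values[float_count]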


def create_binary_file(json_path, flag=True):
    # 1. open the json file
    with open(json_path, 'r', encoding='utf-8') as file:
        # 2. read and parse its content
        data = json.load(file)
    if flag:
        for key, values in data.items():  # the loop variable used to shadow `flag`
            # write each key's value lists as float16 binary
            write_binary_file(json_path.replace('json', 'bin'), values)  # was the undefined global index_file_pth
    else:
        write_binary_file(json_path.replace('.json', '.bin'), [data])


def create_binary_files(index_file_pth):
    if os.path.isfile(index_file_pth):
        create_binary_file(index_file_pth)
    else:
        for name in os.listdir(index_file_pth):
            jsonpth = os.sep.join([index_file_pth, name])
            create_binary_file(jsonpth, False)


if __name__ == '__main__':
    # index_file_pth = '../data/feature_json'  # generate binary files from a directory of json files
    index_file_pth = '../search_library/yunhedian_30-04.json'
    # create_base_index(index_file_pth)  # generate the usearch index file
    create_binary_files(index_file_pth)  # generate the binary file
    # index_file_pth = '../search_library/test_index_10_normal_0717.usearch'
    # # index_file_pth = '../search_library/data_10_normal_0718.index'
    # search_in_index(query='693', index_name=index_file_pth, barcode='6934024590466')
    # # check index data file
    # index_file_pth = '../search_library/data_zhanting.data'
    # # # get_feature_index(index_file_pth, ['6901070602818'])
    # get_feature_index(index_file_pth, ['6923644272159'])
    # index_file_pth = '../search_library/data_zhanting.data'
    # json_file_pth = '../search_library/data_zhanting.json'
    # check_usearch_json_diff(index_file_pth, json_file_pth)