rebuild
This commit is contained in:
233
tools/operate_usearch.py
Normal file
233
tools/operate_usearch.py
Normal file
@ -0,0 +1,233 @@
|
||||
import os
|
||||
import numpy as np
|
||||
from usearch.index import Index
|
||||
import json
|
||||
import struct
|
||||
|
||||
|
||||
def create_index():
|
||||
index = Index(
|
||||
ndim=256,
|
||||
metric='cos',
|
||||
# dtype='f32',
|
||||
dtype='f16',
|
||||
connectivity=32,
|
||||
expansion_add=40, # 128,
|
||||
expansion_search=10, # 64,
|
||||
multi=True
|
||||
)
|
||||
return index
|
||||
|
||||
|
||||
def compare_feature(features1, features2, model='1'):
|
||||
"""
|
||||
:param model 比对策略
|
||||
'0':模拟一个轨迹的图像(所有的图像、或者挑选的若干图像)与标准库,先求每个图片与标准库的最大值,再求所有图片对应最大值的均值
|
||||
'1':带对比的所有相似度的均值
|
||||
'2':比对1:1的最大值
|
||||
:param feature1:
|
||||
:param feature2:
|
||||
:return:
|
||||
"""
|
||||
similarity_group, similarity_groups = [], []
|
||||
if model == '0':
|
||||
for feature1 in features1:
|
||||
for feature2 in features2[0]:
|
||||
similarity = np.dot(feature1, feature2) / (np.linalg.norm(feature1) * np.linalg.norm(feature2))
|
||||
similarity_group.append(similarity)
|
||||
similarity_groups.append(max(similarity_group))
|
||||
similarity_group = []
|
||||
return sum(similarity_groups) / len(similarity_groups)
|
||||
|
||||
elif model == '1':
|
||||
feature2 = features2[0]
|
||||
for feature1 in features1:
|
||||
for num in range(len(feature2)):
|
||||
similarity = np.dot(feature1, feature2[num]) / (
|
||||
np.linalg.norm(feature1) * np.linalg.norm(feature2[num]))
|
||||
similarity_group.append(similarity)
|
||||
similarity_groups.append(sum(similarity_group) / len(similarity_group))
|
||||
similarity_group = []
|
||||
# return sum(similarity_groups)/len(similarity_groups), max(similarity_groups)
|
||||
if len(similarity_groups) == 0:
|
||||
return -1
|
||||
return sum(similarity_groups) / len(similarity_groups)
|
||||
elif model == '2':
|
||||
feature2 = features2[0]
|
||||
for feature1 in features1:
|
||||
for num in range(len(feature2)):
|
||||
similarity = np.dot(feature1, feature2[num]) / (
|
||||
np.linalg.norm(feature1) * np.linalg.norm(feature2[num]))
|
||||
similarity_group.append(similarity)
|
||||
return max(similarity_group)
|
||||
|
||||
def get_barcode_feature(data):
|
||||
barcode = data['key']
|
||||
features = data['value']
|
||||
return [barcode] * len(features), features
|
||||
|
||||
|
||||
def analysis_file(file_path):
|
||||
"""
|
||||
:param file_path:
|
||||
:return:
|
||||
"""
|
||||
barcodes, features = [], []
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
for dic in data['total']:
|
||||
barcode, feature = get_barcode_feature(dic)
|
||||
barcodes.append(barcode)
|
||||
features.append(feature)
|
||||
return barcodes, features
|
||||
|
||||
|
||||
def create_base_index(index_file_pth=None,
|
||||
barcodes=None,
|
||||
features=None,
|
||||
save_index_name=None):
|
||||
index = create_index()
|
||||
if index_file_pth is not None:
|
||||
# save_index_name = index_file_pth.split('json')[0] + 'usearch'
|
||||
save_index_name = index_file_pth.split('json')[0] + 'data'
|
||||
barcodes, features = analysis_file(index_file_pth)
|
||||
else:
|
||||
assert barcodes is not None and features is not None, 'barcodes and features must be not None'
|
||||
for barcode, feature in zip(barcodes, features):
|
||||
try:
|
||||
index.add(np.array(barcode), np.array(feature))
|
||||
except Exception as e:
|
||||
print(e)
|
||||
continue
|
||||
index.save(save_index_name)
|
||||
|
||||
|
||||
def get_feature_index(index_file_pth=None,
|
||||
barcodes=None):
|
||||
assert index_file_pth is not None, 'index_file_pth must be not None'
|
||||
index = Index.restore(index_file_pth, view=True)
|
||||
feature_lists = index.get(np.array(barcodes))
|
||||
print("memory {} size {}".format(index.memory_usage, index.size))
|
||||
print("feature_lists {}".format(feature_lists))
|
||||
return feature_lists
|
||||
|
||||
|
||||
def search_in_index(query=None,
|
||||
barcode=None, # barcode -> int or np.ndarray
|
||||
index_name=None,
|
||||
temp_index=False, # 是否为临时库
|
||||
model='0',
|
||||
):
|
||||
if temp_index:
|
||||
assert index_name is not None, 'index_name must be not None'
|
||||
index = Index.restore(index_name, view=True)
|
||||
if barcode is not None: # 1:1对比测试
|
||||
feature_lists = index.get(np.array(barcode))
|
||||
results = compare_feature(query, feature_lists)
|
||||
else:
|
||||
results = index.search(query, count=5)
|
||||
return results
|
||||
else: # 标准库
|
||||
assert index_name is not None, 'index_name must be not None'
|
||||
index = Index.restore(index_name, view=True)
|
||||
if barcode is not None: # 1:1对比测试
|
||||
feature_lists = index.get(np.array(barcode))
|
||||
results = compare_feature(query, feature_lists, model)
|
||||
else:
|
||||
results = index.search(query, count=10)
|
||||
return results
|
||||
|
||||
|
||||
def delete_index(index_name=None, key=None, index=None):
|
||||
assert key is not None, 'key must be not None'
|
||||
if index is None:
|
||||
assert index_name is not None, 'index_name must be not None'
|
||||
index = Index.restore(index_name, view=True)
|
||||
index.remove(index_name)
|
||||
else:
|
||||
index.remove(key)
|
||||
|
||||
from scipy.spatial.distance import cdist
|
||||
def compute_similarity_matrix(featurelists1, featurelists2):
|
||||
"""计算图片之间的余弦相似度矩阵"""
|
||||
# 计算所有向量对之间的余弦相似度
|
||||
cosine_similarities = 1 - cdist(featurelists1, featurelists2, metric='cosine')
|
||||
cosine_similarities = np.around(cosine_similarities, decimals=3)
|
||||
return cosine_similarities
|
||||
|
||||
def check_usearch_json_diff(index_file_pth, json_file_pth):
|
||||
json_features = None
|
||||
feature_lists = get_feature_index(index_file_pth, ['6923644272159'])
|
||||
with open(json_file_pth, 'r') as json_file:
|
||||
json_data = json.load(json_file)
|
||||
for data in json_data['total']:
|
||||
if data['key'] == '6923644272159':
|
||||
json_features = data['value']
|
||||
json_features = np.array(json_features)
|
||||
feature_lists = np.array(feature_lists[0])
|
||||
compute_similarity_matrix(json_features, feature_lists)
|
||||
|
||||
|
||||
def write_binary_file(filename, datas):
|
||||
with open(filename, 'wb') as f:
|
||||
# 先写入数据中的key数量(为C++读取提供便利)
|
||||
key_count = len(datas)
|
||||
f.write(struct.pack('I', key_count)) # 'I'代表无符号整型(4字节)
|
||||
|
||||
for data in datas:
|
||||
key = data['key']
|
||||
feats = data['value']
|
||||
key_bytes = key.encode('utf-8')
|
||||
key_len = len(key)
|
||||
length_byte = struct.pack('<B', key_len)
|
||||
f.write(length_byte)
|
||||
# f.write(struct.pack('Q', len(key_bytes)))
|
||||
f.write(key_bytes)
|
||||
value_count = len(feats)
|
||||
f.write(struct.pack('I', (value_count * 256)))
|
||||
# 遍历字典,写入每个key及其对应的浮点数值列表
|
||||
for values in feats:
|
||||
# 写入每个浮点数值(保留小数点后六位)
|
||||
for value in values:
|
||||
# 使用'f'格式(单精度浮点,4字节),并四舍五入保留六位小数
|
||||
value_half = np.float16(value)
|
||||
# print(value_half.tobytes())
|
||||
f.write(value_half.tobytes())
|
||||
def create_binary_file(json_path, flag=True):
|
||||
# 1. 打开JSON文件
|
||||
with open(json_path, 'r', encoding='utf-8') as file:
|
||||
# 2. 读取并解析JSON文件内容
|
||||
data = json.load(file)
|
||||
if flag:
|
||||
for flag, values in data.items():
|
||||
# 逐个写入values中的每个值,保留小数点后六位,每个值占一行
|
||||
write_binary_file(index_file_pth.replace('json', 'bin'), values)
|
||||
else:
|
||||
write_binary_file(json_path.replace('.json', '.bin'), [data])
|
||||
|
||||
def create_binary_files(index_file_pth):
|
||||
if os.path.isfile(index_file_pth):
|
||||
create_binary_file(index_file_pth)
|
||||
else:
|
||||
for name in os.listdir(index_file_pth):
|
||||
jsonpth = os.sep.join([index_file_pth, name])
|
||||
create_binary_file(jsonpth, False)
|
||||
|
||||
if __name__ == '__main__':
|
||||
# index_file_pth = '../data/feature_json' # 生成二进制文件 多文件
|
||||
index_file_pth = '../search_library/yunhedian_30-04.json'
|
||||
# create_base_index(index_file_pth) # 生成usearch文件
|
||||
create_binary_files(index_file_pth) # 生成二进制文件 多文件
|
||||
|
||||
# index_file_pth = '../search_library/test_index_10_normal_0717.usearch'
|
||||
# # index_file_pth = '../search_library/data_10_normal_0718.index'
|
||||
# search_in_index(query='693', index_name=index_file_pth, barcode='6934024590466')
|
||||
|
||||
# # check index data file
|
||||
# index_file_pth = '../search_library/data_zhanting.data'
|
||||
# # # get_feature_index(index_file_pth, ['6901070602818'])
|
||||
# get_feature_index(index_file_pth, ['6923644272159'])
|
||||
|
||||
# index_file_pth = '../search_library/data_zhanting.data'
|
||||
# json_file_pth = '../search_library/data_zhanting.json'
|
||||
# check_usearch_json_diff(index_file_pth, json_file_pth)
|
Reference in New Issue
Block a user