import os
import json
import struct

import numpy as np
from scipy.spatial.distance import cdist
from usearch.index import Index


def create_index():
    index = Index(
        ndim=256,
        metric='cos',
        # dtype='f32',
        dtype='f16',
        connectivity=32,
        expansion_add=40,  # 128,
        expansion_search=10,  # 64,
        multi=True
    )
    return index
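

# Minimal usage sketch for the index above. Illustrative only: the key
# 6923644272159 and the random vectors are made-up sample data.
def demo_index_usage():
    """Add a batch of vectors under one key and run a nearest-neighbor search."""
    index = create_index()
    vectors = np.random.rand(3, 256).astype(np.float32)
    # multi=True lets several vectors share the same key
    index.add(np.full(3, 6923644272159, dtype=np.int64), vectors)
    matches = index.search(vectors[0], 5)  # top-5 nearest neighbors
    print(matches.keys, matches.distances)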


def compare_feature(features1, features2, model='1'):
    """
    :param features1: query feature vectors (e.g. all or selected images of one track)
    :param features2: gallery feature lists; features2[0] holds the vectors of the matched barcode
    :param model: comparison strategy
        '0': for each query image take the max similarity against the gallery,
             then average those maxima over all query images
        '1': mean of all pairwise similarities
        '2': max over all 1:1 pairs
    :return: similarity score, or -1 if there is nothing to compare
    """
    similarity_group, similarity_groups = [], []
    if model == '0':
        for feature1 in features1:
            for feature2 in features2[0]:
                similarity = np.dot(feature1, feature2) / (np.linalg.norm(feature1) * np.linalg.norm(feature2))
                similarity_group.append(similarity)
            similarity_groups.append(max(similarity_group))
            similarity_group = []
        if len(similarity_groups) == 0:
            return -1
        return sum(similarity_groups) / len(similarity_groups)
    elif model == '1':
        feature2 = features2[0]
        for feature1 in features1:
            for num in range(len(feature2)):
                similarity = np.dot(feature1, feature2[num]) / (
                        np.linalg.norm(feature1) * np.linalg.norm(feature2[num]))
                similarity_group.append(similarity)
            similarity_groups.append(sum(similarity_group) / len(similarity_group))
            similarity_group = []
        # return sum(similarity_groups)/len(similarity_groups), max(similarity_groups)
        if len(similarity_groups) == 0:
            return -1
        return sum(similarity_groups) / len(similarity_groups)
    elif model == '2':
        feature2 = features2[0]
        for feature1 in features1:
            for num in range(len(feature2)):
                similarity = np.dot(feature1, feature2[num]) / (
                        np.linalg.norm(feature1) * np.linalg.norm(feature2[num]))
                similarity_group.append(similarity)
        if len(similarity_group) == 0:
            return -1
        return max(similarity_group)
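

# Toy check for compare_feature. Illustrative only: 4-dim vectors stand in
# for the real 256-dim features; one query matches a gallery vector exactly
# and the other is orthogonal to it.
def demo_compare_feature():
    """Expected: model '0' -> 1.0, model '1' -> 0.5, model '2' -> 1.0."""
    query = [np.array([1.0, 0.0, 0.0, 0.0]), np.array([0.0, 1.0, 0.0, 0.0])]
    gallery = [np.array([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]])]
    print(compare_feature(query, gallery, model='0'))  # mean of per-image maxima
    print(compare_feature(query, gallery, model='1'))  # mean of all pairs
    print(compare_feature(query, gallery, model='2'))  # max over all pairs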


def get_barcode_feature(data):
    barcode = data['key']
    features = data['value']
    # repeat the barcode once per feature vector so keys and vectors stay aligned
    return [barcode] * len(features), features


def analysis_file(file_path):
    """
    Parse a feature-library JSON file into parallel per-barcode lists.
    :param file_path: path to a JSON file with a 'total' list of {key, value} dicts
    :return: (barcodes, features) lists, one entry per barcode
    """
    barcodes, features = [], []
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
        for dic in data['total']:
            barcode, feature = get_barcode_feature(dic)
            barcodes.append(barcode)
            features.append(feature)
    return barcodes, features
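

# Expected JSON layout for analysis_file, inferred from get_barcode_feature
# (the barcode and the numbers below are illustrative sample values only):
#
# {
#   "total": [
#     {"key": "6923644272159", "value": [[0.013, -0.072, ...], ...]},
#     ...
#   ]
# }
#
# Each "value" entry is a list of 256-dim feature vectors for that barcode.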


def create_base_index(index_file_pth=None,
                      barcodes=None,
                      features=None,
                      save_index_name=None):
    index = create_index()
    if index_file_pth is not None:
        # save_index_name = index_file_pth.split('json')[0] + 'usearch'
        save_index_name = index_file_pth.split('json')[0] + 'data'
        barcodes, features = analysis_file(index_file_pth)
    else:
        assert barcodes is not None and features is not None, 'barcodes and features must not be None'
    for barcode, feature in zip(barcodes, features):
        try:
            # usearch keys must be integers, so cast the barcode strings
            index.add(np.array(barcode, dtype=np.int64), np.array(feature))
        except Exception as e:
            print(e)
            continue
    index.save(save_index_name)


def get_feature_index(index_file_pth=None,
                      barcodes=None):
    assert index_file_pth is not None, 'index_file_pth must not be None'
    index = Index.restore(index_file_pth, view=True)
    # keys are stored as integers, so cast the barcode strings before lookup
    feature_lists = index.get(np.array(barcodes, dtype=np.int64))
    print("memory {} size {}".format(index.memory_usage, index.size))
    print("feature_lists {}".format(feature_lists))
    return feature_lists


def search_in_index(query=None,
                    barcode=None,  # barcode -> int or np.ndarray
                    index_name=None,
                    temp_index=False,  # True: temporary library, False: standard library
                    model='0',
                    ):
    assert index_name is not None, 'index_name must not be None'
    index = Index.restore(index_name, view=True)
    if barcode is not None:  # 1:1 comparison test
        feature_lists = index.get(np.array(barcode))
        if temp_index:
            results = compare_feature(query, feature_lists)
        else:
            results = compare_feature(query, feature_lists, model)
    else:  # 1:N search; temporary libraries return fewer candidates
        results = index.search(query, count=5 if temp_index else 10)
    return results


def delete_index(index_name=None, key=None, index=None):
    assert key is not None, 'key must not be None'
    if index is None:
        assert index_name is not None, 'index_name must not be None'
        # restore a mutable copy (a view is read-only), remove the key, and persist
        index = Index.restore(index_name)
        index.remove(key)
        index.save(index_name)
    else:
        index.remove(key)


def compute_similarity_matrix(featurelists1, featurelists2):
    """Compute the cosine-similarity matrix between two sets of feature vectors."""
    # cdist returns cosine *distance*, so similarity = 1 - distance
    cosine_similarities = 1 - cdist(featurelists1, featurelists2, metric='cosine')
    cosine_similarities = np.around(cosine_similarities, decimals=3)
    return cosine_similarities
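

# Self-check for compute_similarity_matrix with toy 4-dim vectors
# (illustrative only): identical rows give 1.0, orthogonal rows give 0.0.
def demo_similarity_matrix():
    a = np.array([[1.0, 0.0, 0.0, 0.0],
                  [0.0, 1.0, 0.0, 0.0]])
    print(compute_similarity_matrix(a, a))  # -> [[1. 0.] [0. 1.]]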


def check_usearch_json_diff(index_file_pth, json_file_pth):
    json_features = None
    feature_lists = get_feature_index(index_file_pth, ['6923644272159'])
    with open(json_file_pth, 'r') as json_file:
        json_data = json.load(json_file)
        for data in json_data['total']:
            if data['key'] == '6923644272159':
                json_features = data['value']
    json_features = np.array(json_features)
    feature_lists = np.array(feature_lists[0])
    print(compute_similarity_matrix(json_features, feature_lists))


def write_binary_file(filename, datas):
    with open(filename, 'wb') as f:
        # first write the number of keys (convenient for the C++ reader)
        key_count = len(datas)
        f.write(struct.pack('I', key_count))  # 'I' = unsigned int (4 bytes)

        for data in datas:
            key = data['key']
            feats = data['value']
            key_bytes = key.encode('utf-8')
            key_len = len(key_bytes)  # byte length, not character count
            length_byte = struct.pack('<B', key_len)
            f.write(length_byte)
            # f.write(struct.pack('Q', len(key_bytes)))
            f.write(key_bytes)
            value_count = len(feats)
            # total number of floats for this key (vectors * 256 dims)
            f.write(struct.pack('I', (value_count * 256)))
            # write every feature value belonging to this key
            for values in feats:
                for value in values:
                    # store as half-precision float (2 bytes each)
                    value_half = np.float16(value)
                    # print(value_half.tobytes())
                    f.write(value_half.tobytes())
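

# Sketch of the matching reader for write_binary_file's layout, to document
# the binary format. Assumptions: native-endian 'I' fields, 256-dim float16
# vectors; untested against the real C++ reader.
def read_binary_file(filename):
    """Layout per file:
        uint32 key_count
        per key: uint8 key_len, key_len utf-8 bytes,
                 uint32 float_count, float_count float16 values
    """
    datas = []
    with open(filename, 'rb') as f:
        key_count = struct.unpack('I', f.read(4))[0]
        for _ in range(key_count):
            key_len = struct.unpack('<B', f.read(1))[0]
            key = f.read(key_len).decode('utf-8')
            float_count = struct.unpack('I', f.read(4))[0]
            values = np.frombuffer(f.read(float_count * 2), dtype=np.float16)
            datas.append({'key': key, 'value': values.reshape(-1, 256)})
    return datas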


def create_binary_file(json_path, flag=True):
    # open and parse the JSON file
    with open(json_path, 'r', encoding='utf-8') as file:
        data = json.load(file)
    if flag:
        # top-level dict: write the feature lists under each key
        for _, values in data.items():
            write_binary_file(json_path.replace('.json', '.bin'), values)
    else:
        write_binary_file(json_path.replace('.json', '.bin'), [data])


def create_binary_files(index_file_pth):
    if os.path.isfile(index_file_pth):
        create_binary_file(index_file_pth)
    else:
        # a directory: convert every JSON file inside it
        for name in os.listdir(index_file_pth):
            jsonpth = os.sep.join([index_file_pth, name])
            create_binary_file(jsonpth, False)


if __name__ == '__main__':
    # index_file_pth = '../data/feature_json'  # generate binary files (multiple files)
    index_file_pth = '../search_library/yunhedian_30-04.json'
    # create_base_index(index_file_pth)  # generate the usearch file
    create_binary_files(index_file_pth)  # generate binary file(s)

    # index_file_pth = '../search_library/test_index_10_normal_0717.usearch'
    # # index_file_pth = '../search_library/data_10_normal_0718.index'
    # search_in_index(query='693', index_name=index_file_pth, barcode='6934024590466')

    # # check index data file
    # index_file_pth = '../search_library/data_zhanting.data'
    # # # get_feature_index(index_file_pth, ['6901070602818'])
    # get_feature_index(index_file_pth, ['6923644272159'])

    # index_file_pth = '../search_library/data_zhanting.data'
    # json_file_pth = '../search_library/data_zhanting.json'
    # check_usearch_json_diff(index_file_pth, json_file_pth)