import os
import json
import statistics

import numpy as np
from usearch.index import Index


def create_index(ndim=256, metric='cos', dtype='f16', connectivity=32,
                 expansion_add=40, expansion_search=10, multi=True):
    """Create an empty usearch ``Index`` configured for barcode feature vectors.

    The defaults reproduce the settings this module has always used.
    Previously tried alternatives were dtype='f32', expansion_add=128 and
    expansion_search=64.

    :param ndim: dimensionality of each feature vector.
    :param metric: distance metric passed to usearch ('cos' = cosine).
    :param dtype: storage scalar type of the vectors inside the index.
    :param connectivity: graph connectivity passed to usearch.
    :param expansion_add: search breadth used while inserting.
    :param expansion_search: search breadth used while querying.
    :param multi: allow several vectors per key. One barcode maps to many
        feature vectors, see ``create_base_index``.
    :return: a new, empty ``Index``.
    """
    return Index(
        ndim=ndim,
        metric=metric,
        dtype=dtype,
        connectivity=connectivity,
        expansion_add=expansion_add,
        expansion_search=expansion_search,
        multi=multi,
    )


def _cosine_matrix(queries, references):
    """Return the pairwise cosine-similarity matrix, shape (len(queries), len(references))."""
    # Compute in float64: vectors read back from an f16 index would otherwise
    # be multiplied and normed in half precision.
    q = np.atleast_2d(np.asarray(queries, dtype=np.float64))
    r = np.atleast_2d(np.asarray(references, dtype=np.float64))
    q_norm = np.linalg.norm(q, axis=1, keepdims=True)
    r_norm = np.linalg.norm(r, axis=1, keepdims=True)
    # Zero-norm vectors yield nan/inf, as the original per-pair formula did.
    return (q @ r.T) / (q_norm * r_norm.T)


def compare_feature(features1, features2, model='1'):
    """Score query features against the stored features of one barcode.

    :param features1: query feature vectors, one vector per image.
    :param features2: result of ``Index.get`` for a list of keys. Only
        ``features2[0]``, the vectors stored for the first key, is used.
    :param model: comparison strategy.
        '0': simulate one trajectory's images (all of them, or a selected
             subset) against the standard library. Take each query image's
             maximum similarity against the library, then average those
             maxima.
        '1': mean of all pairwise similarities.
        '2': maximum of all pairwise (1:1) similarities.
    :return: the similarity score as a float, or -1 when either side has no
        vectors to compare.
    :raises ValueError: if ``model`` is not '0', '1' or '2'.
    """
    if model not in ('0', '1', '2'):
        raise ValueError("model must be '0', '1' or '2', got {!r}".format(model))

    references = features2[0] if features2 is not None and len(features2) > 0 else None
    # The original code crashed on empty input for models '0' and '2'
    # (ZeroDivisionError / max of empty). All strategies now share the -1
    # sentinel that model '1' already used.
    if features1 is None or len(features1) == 0 or references is None or len(references) == 0:
        return -1

    similarity = _cosine_matrix(features1, references)
    if model == '0':
        return float(similarity.max(axis=1).mean())
    if model == '1':
        # Every row has the same length, so the mean of row means equals the
        # overall mean.
        return float(similarity.mean())
    return float(similarity.max())


def get_barcode_feature(data):
    """Split one JSON record into aligned key and vector lists.

    :param data: a dict with 'key' (the barcode) and 'value' (its feature list).
    :return: ``([barcode] * n, features)``, so each vector has its own key entry.
    """
    barcode = data['key']
    features = data['value']
    return [barcode] * len(features), features


def analysis_file(file_path):
    """Load barcode/feature records from a JSON library file.

    Expects the layout ``{"total": [{"key": ..., "value": [...]}, ...]}``.

    :param file_path: path to the JSON file.
    :return: ``(barcodes, features)``. Both are lists with one entry per
        record, as produced by ``get_barcode_feature``.
    """
    barcodes, features = [], []
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    for dic in data['total']:
        barcode, feature = get_barcode_feature(dic)
        barcodes.append(barcode)
        features.append(feature)
    return barcodes, features


def create_base_index(index_file_pth=None, barcodes=None, features=None, save_index_name=None):
    """Build an index from a JSON library file or from in-memory data, and save it.

    :param index_file_pth: JSON library file. When given, ``barcodes`` and
        ``features`` are read from it. If ``save_index_name`` is None, the
        output path is the same path with a '.data' extension.
    :param barcodes: per-record key lists; used only when ``index_file_pth``
        is None.
    :param features: per-record vector lists, aligned with ``barcodes``.
    :param save_index_name: output path for the saved index.
    """
    index = create_index()
    if index_file_pth is not None:
        if save_index_name is None:
            # splitext instead of split('json'): the old form broke whenever
            # "json" appeared elsewhere in the path.
            save_index_name = os.path.splitext(index_file_pth)[0] + '.data'
        barcodes, features = analysis_file(index_file_pth)
    else:
        assert barcodes is not None and features is not None, 'barcodes and features must be not None'
        assert save_index_name is not None, 'save_index_name must be not None'
    for barcode, feature in zip(barcodes, features):
        if len(feature) == 0:
            # An empty vector list cannot be added as an (n, ndim) batch.
            continue
        index.add(np.array(barcode), np.array(feature))
    index.save(save_index_name)


def get_feature_index(index_file_pth=None, barcodes=None):
    """Fetch the stored feature vectors for the given barcodes from a saved index.

    :param index_file_pth: path of a saved index.
    :param barcodes: keys to look up.
    :return: whatever ``Index.get`` returns for those keys.
    """
    assert index_file_pth is not None, 'index_file_pth must be not None'
    index = Index.restore(index_file_pth, view=True)
    feature_lists = index.get(np.array(barcodes))
    print("memory {} size {}".format(index.memory_usage, index.size))
    return feature_lists


def search_in_index(query=None,
                    barcode=None,  # int, str or np.ndarray of keys
                    index_name=None,
                    temp_index=False,  # True for a temporary library
                    model='0',
                    ):
    """Either compare a query against one barcode (1:1) or run a top-k search.

    :param query: query feature vector(s).
    :param barcode: when given, run a 1:1 comparison against this key's
        stored vectors and return a ``compare_feature`` score.
    :param index_name: path of a saved index.
    :param temp_index: True when ``index_name`` is a temporary library.
    :param model: ``compare_feature`` strategy for the standard library.
    :return: a similarity score (1:1 mode) or ``Index.search`` results.
    """
    assert index_name is not None, 'index_name must be not None'
    index = Index.restore(index_name, view=True)
    if barcode is not None:  # 1:1 comparison
        # atleast_1d makes get() return per-key results, so features2[0] in
        # compare_feature is the first key's vector list.
        feature_lists = index.get(np.atleast_1d(np.array(barcode)))
        # NOTE(review): the temporary-library path has always ignored `model`
        # and used compare_feature's default strategy '1'. Kept for
        # compatibility; confirm this is intended.
        strategy = '1' if temp_index else model
        return compare_feature(query, feature_lists, strategy)
    return index.search(query, count=5 if temp_index else 10)


def delete_index(index_name=None, key=None, index=None):
    """Remove ``key`` from an in-memory index, or from a saved index file.

    :param index_name: path of a saved index. Used when ``index`` is None; the
        file is rewritten after removal.
    :param key: key (barcode) to remove.
    :param index: an already-loaded ``Index`` to mutate in place (not saved).
    :return: the value returned by ``Index.remove``.
    """
    assert key is not None, 'key must be not None'
    if index is not None:
        return index.remove(key)
    assert index_name is not None, 'index_name must be not None'
    # The old code removed `index_name` instead of `key`, restored a read-only
    # view=True memory map, and never saved. Load fully, remove, then persist.
    index = Index.restore(index_name)
    removed = index.remove(key)
    index.save(index_name)
    return removed


if __name__ == '__main__':
    # Build an index from a JSON library:
    #   create_base_index('../search_library/data_0923.json')
    # 1:1 comparison against one barcode:
    #   search_in_index(query=..., index_name='../search_library/test_index_10_normal_0717.usearch',
    #                   barcode='6934024590466')

    # Inspect the stored features of one barcode in the index data file.
    index_file_pth = '../search_library/data_0923.data'
    get_feature_index(index_file_pth, ['6934230050105'])