import os
import numpy as np
from ytracking.track_ import *
from tools.Interface import AiInterface, AiClass
from tools.operate_usearch import create_base_index, search_in_index
from tools.initModel import models
from imgcompare import get_feature_list, compute_similarity_matrix
import pickle

# Module import initialises the models and creates an AiClass instance
# (kept as in the original; other modules may rely on these side effects).
models.initModel()
ai_obj = AiClass()


def get_img_lists(pth, extensions=('.jpg',)):
    """Collect image paths under ``pth``, grouped by leaf directory.

    Only directories with no sub-directories are considered; each such
    directory contributes one group containing the files whose names end with
    one of ``extensions``. Leaf directories with no matching files are skipped
    so callers never receive an empty group.

    :param pth: root directory to walk.
    :param extensions: a suffix string or tuple of suffixes accepted by
        ``str.endswith`` (case-sensitive). Defaults to ``('.jpg',)``.
    :return: list of lists of image paths, one inner list per leaf directory.
    """
    imglists = []
    for root, dirs, files in os.walk(pth):
        if dirs:
            # Not a leaf directory; its images (if any) are ignored.
            continue
        imglist = [os.sep.join([root, file]) for file in files if file.endswith(extensions)]
        if imglist:
            imglists.append(imglist)
    return imglists


def get_standard_image(cosine_similarities, similarity_threshold=0.6, min_remaining=10):
    """Greedily select representative ("standard") image indices.

    Every row is scored by how many entries exceed ``similarity_threshold``.
    Candidates are ordered by that score (descending, ties by ascending
    index). The top candidate is selected, then it and every index similar to
    it are removed from the pool. This repeats while more than
    ``min_remaining`` candidates remain, or stops early (after selecting) when
    the chosen row has fewer than two entries above the threshold.

    :param cosine_similarities: 2-D array of pairwise similarities; presumably
        square and symmetric as produced by ``compute_similarity_matrix`` —
        TODO confirm.
    :param similarity_threshold: entries strictly greater than this count as similar.
    :param min_remaining: selection stops once the pool size is at most this.
    :return: list of selected row indices in selection order (empty when the
        matrix has ``min_remaining`` rows or fewer).
    """
    cosine_similarities = np.asarray(cosine_similarities)
    counts = (cosine_similarities > similarity_threshold).sum(axis=1)
    # sorted() is stable, so equal counts keep ascending index order.
    candidates = sorted(range(counts.shape[0]), key=lambda k: counts[k], reverse=True)

    target_indexs = []
    while len(candidates) > min_remaining:
        best = candidates[0]
        target_indexs.append(best)
        similar = np.flatnonzero(cosine_similarities[best, :] > similarity_threshold)
        if len(similar) < 2:
            break
        drop = set(similar.tolist())
        # Always drop the pick itself: if its diagonal entry is not above the
        # threshold the pool would otherwise never shrink (infinite loop).
        drop.add(best)
        candidates = [k for k in candidates if k not in drop]
    return target_indexs


def create_feature_library(pth, save_index_name, index_file_pth=None,
                           barcode_list_pth='search_library/target_barcode_lists.pkl'):
    """Build a search index from representative images of every leaf folder.

    For each image group from :func:`get_img_lists`, features are extracted,
    a similarity matrix is computed and :func:`get_standard_image` picks the
    representative images. Their features and barcodes (the file-name prefix
    before the first ``_``) are passed to ``create_base_index`` and the
    barcode lists are pickled to ``barcode_list_pth``.

    :param pth: root directory of the image groups.
    :param save_index_name: output index name forwarded to ``create_base_index``.
    :param index_file_pth: forwarded unchanged to ``create_base_index``.
    :param barcode_list_pth: where the pickled barcode lists are written; its
        parent directory is created if missing.
    """
    target_feature_lists, target_barcode_lists = [], []
    for imglist in get_img_lists(pth):
        feature_list = get_feature_list(imglist, False)
        cosine_similarities = compute_similarity_matrix(feature_list)
        target_indexs = get_standard_image(cosine_similarities)
        target_feature_lists.append([feature_list[i] for i in target_indexs])
        target_barcode_lists.append([os.path.basename(imglist[i]).split('_')[0] for i in target_indexs])

    create_base_index(save_index_name=save_index_name,
                      barcodes=target_barcode_lists,
                      features=target_feature_lists,
                      index_file_pth=index_file_pth)

    barcode_dir = os.path.dirname(barcode_list_pth)
    if barcode_dir:
        os.makedirs(barcode_dir, exist_ok=True)
    with open(barcode_list_pth, 'wb') as f:
        pickle.dump(target_barcode_lists, f)


def search_top_in_index(test_image_pth, index_name):
    """1:N search: query the index with every file in ``test_image_pth``.

    :param test_image_pth: directory whose entries are all treated as query images.
    :param index_name: index to search, forwarded to ``search_in_index``.
    :return: ``(s_barcode, s_similarity)`` numpy arrays, one row per query,
        holding ``result.keys`` and ``1 - result.distances`` respectively
        (assumes the index stores a distance where 1 - distance is a
        similarity — TODO confirm metric).
    """
    s_barcode, s_similarity = [], []
    img_lists = [os.sep.join([test_image_pth, name]) for name in os.listdir(test_image_pth)]
    feature_lists = get_feature_list(img_lists, False)
    for feature in feature_lists:
        result = search_in_index(query=np.array(feature), index_name=index_name)
        s_barcode.append(result.keys)
        s_similarity.append(1 - result.distances)
    return np.array(s_barcode), np.array(s_similarity)


def search_one_in_index(test_image_pth, index_name):
    """1:1 search: compare query images against their own barcodes in the index.

    Barcodes are parsed as ``int`` from the file-name prefix before the first
    ``_`` (NOTE(review): ``create_feature_library`` stores them as strings —
    confirm ``search_in_index`` expects ints here).

    :param test_image_pth: directory whose entries are all treated as query images.
    :param index_name: index to search, forwarded to ``search_in_index``.
    :return: whatever ``search_in_index`` returns for this query (previously
        computed and discarded).
    """
    names = os.listdir(test_image_pth)
    barcodes = list({int(os.path.basename(name).split('_')[0]) for name in names})
    img_lists = [os.sep.join([test_image_pth, name]) for name in names]
    feature_lists = get_feature_list(img_lists, False)
    result = search_in_index(barcode=barcodes, query=feature_lists,
                             index_name=index_name, temp_index=False)
    return result


if __name__ == '__main__':
    pth = 'imageQualityData/test_data'
    save_index_name = 'search_library/test_index_10_simple_0717.usearch'
    create_feature_library(pth, save_index_name=save_index_name)

    # test_images_pth = 'D:/Project/ieemoo/image_quality_assessment/imageQualityData/test_images'
    # # index_name = 'D:/Project/ieemoo/image_quality_assessment/search_library/test_index_10_normal_0717.usearch'
    # index_name = 'D:/Project/ieemoo/image_quality_assessment/search_library/test_index_10_simple_0717.usearch'
    # # search_top_in_index(test_images_pth, index_name)
    # search_one_in_index(test_images_pth, index_name)