104 lines
4.2 KiB
Python
104 lines
4.2 KiB
Python
import os
|
|
|
|
import numpy as np
|
|
|
|
from ytracking.track_ import *
|
|
from tools.Interface import AiInterface, AiClass
|
|
from tools.operate_usearch import create_base_index, search_in_index
|
|
from tools.initModel import models
|
|
from imgcompare import get_feature_list, compute_similarity_matrix
|
|
import pickle
|
|
# Import-time side effect: initialise the project's models before any
# feature extraction below runs. NOTE(review): presumably get_feature_list
# depends on this having been called -- confirm in tools.initModel.
models.initModel()
|
|
# Module-level AiClass instance, created once at import time.
# NOTE(review): not referenced by the functions in this file; it may be kept
# for its constructor side effects or for importers -- confirm before removing.
ai_obj = AiClass()
|
|
|
|
|
|
def get_img_lists(pth, suffixes=('.jpg',)):
    """Collect image paths grouped by leaf directory under *pth*.

    The tree is walked with ``os.walk``. Each directory that has no
    sub-directories contributes one group, which contains the paths of
    its files whose names end with one of *suffixes*.

    :param pth: root directory to scan.
    :param suffixes: suffix string or tuple of suffixes to keep
        (case-sensitive, as ``str.endswith``); defaults to ``('.jpg',)``.
    :return: list of groups, one per leaf directory in sorted traversal
        order. Each group is a list of file paths sorted by file name. A
        leaf directory without matching files still yields an empty group.
    """
    groups = []
    for root, dirs, files in os.walk(pth):
        # Sorting ``dirs`` in place makes os.walk descend in a deterministic
        # order; otherwise the order is whatever the filesystem returns.
        dirs.sort()
        if dirs:
            continue  # only leaf directories form a group
        groups.append([os.path.join(root, name)
                       for name in sorted(files)
                       if name.endswith(suffixes)])
    return groups
|
|
|
|
def get_standard_image(cosine_similarities, similarity_threshold=0.6, stop_size=10):
    """Greedily select representative ("standard") image indices.

    Row ``i`` of *cosine_similarities* holds the similarities of image ``i``
    to every image. Images are ranked by how many entries of their row
    exceed *similarity_threshold*: most first, and ties keep ascending index
    order. The top-ranked remaining image is selected. Then it, and every
    image its row marks as similar, are dropped from the candidates. This
    repeats while more than *stop_size* candidates remain.

    :param cosine_similarities: square similarity matrix (anything that
        ``np.asarray`` turns into a 2-D array).
    :param similarity_threshold: strict lower bound for "similar".
    :param stop_size: stop once at most this many candidates remain. The
        default of 10 is the previously hard-coded value. NOTE(review): with
        the default, a group of 10 or fewer images yields no selection at all.
    :return: list of selected row indices (Python ints), in selection order.
        The loop also stops right after selecting an image whose row has
        fewer than two entries above the threshold.
    """
    sims = np.asarray(cosine_similarities)
    mask = sims > similarity_threshold
    counts = mask.sum(axis=1)

    # sorted(..., reverse=True) is stable, so equal counts stay in index order.
    candidates = sorted(range(counts.shape[0]), key=lambda idx: counts[idx], reverse=True)

    selected = []
    while len(candidates) > stop_size:
        pick = candidates[0]
        selected.append(pick)
        neighbours = np.flatnonzero(mask[pick])
        if neighbours.size < 2:
            break
        covered = set(neighbours.tolist())
        # Always drop the pick itself. If its own diagonal entry is not above
        # the threshold (e.g. a NaN entry), it would otherwise stay first
        # forever and the loop would never terminate.
        covered.add(pick)
        candidates = [idx for idx in candidates if idx not in covered]
    return selected
|
|
|
|
def create_feature_library(pth, save_index_name, index_file_pth=None,
                           barcode_pkl_pth='search_library/target_barcode_lists.pkl'):
    """Build a search index from the representative images of each leaf folder.

    For every group returned by ``get_img_lists(pth)``:

    - features are extracted with ``get_feature_list``;
    - a similarity matrix is computed with ``compute_similarity_matrix``;
    - ``get_standard_image`` picks the representative images.

    Their features and barcodes (the file-name prefix before the first
    ``_``) are passed to ``create_base_index``. The barcode lists are also
    pickled to *barcode_pkl_pth*.

    :param pth: root directory of the image library.
    :param save_index_name: passed through to ``create_base_index``.
    :param index_file_pth: passed through to ``create_base_index``.
    :param barcode_pkl_pth: where to pickle the per-group barcode lists;
        defaults to the previously hard-coded path. Its parent directory is
        created if it does not exist.
    """
    # Create the pickle's folder before the (expensive) feature/index work so
    # a missing directory cannot make the final dump fail.
    pkl_dir = os.path.dirname(barcode_pkl_pth)
    if pkl_dir:
        os.makedirs(pkl_dir, exist_ok=True)

    target_feature_lists, target_barcode_lists = [], []
    for imglist in get_img_lists(pth):
        if not imglist:
            # A leaf folder without images contributes an empty entry without
            # running feature extraction on an empty input.
            target_feature_lists.append([])
            target_barcode_lists.append([])
            continue
        feature_list = get_feature_list(imglist, False)
        cosine_similarities = compute_similarity_matrix(feature_list)
        target_indexs = get_standard_image(cosine_similarities)
        target_feature_lists.append([feature_list[i] for i in target_indexs])
        target_barcode_lists.append(
            [os.path.basename(imglist[i]).split('_')[0] for i in target_indexs])

    create_base_index(save_index_name=save_index_name,
                      barcodes=target_barcode_lists,
                      features=target_feature_lists,
                      index_file_pth=index_file_pth)

    with open(barcode_pkl_pth, 'wb') as f:
        pickle.dump(target_barcode_lists, f)
|
|
|
|
def search_top_in_index(test_image_pth, index_name):  # 1:N
    """Query the index with every image in a folder (1:N search).

    Every entry of ``os.listdir(test_image_pth)`` is treated as an image.
    Features come from ``get_feature_list(paths, False)``, and each feature
    is queried on its own with ``search_in_index``.

    :param test_image_pth: directory whose entries are the query images.
    :param index_name: index passed through to ``search_in_index``.
    :return: ``(barcodes, similarities)`` as numpy arrays with one row per
        query. ``barcodes`` holds each result's ``keys``. ``similarities``
        holds ``1 - distances``, presumably turning a cosine distance into a
        similarity (confirm against the index metric).
    """
    query_paths = [os.sep.join([test_image_pth, entry])
                   for entry in os.listdir(test_image_pth)]
    query_features = get_feature_list(query_paths, False)

    found_keys, found_sims = [], []
    for vec in query_features:
        hit = search_in_index(query=np.array(vec), index_name=index_name)
        found_keys.append(hit.keys)
        found_sims.append(1 - hit.distances)

    return np.array(found_keys), np.array(found_sims)
|
|
|
|
def search_one_in_index(test_image_pth, index_name):  # 1:1
    """Run a 1:1 search of the images in a folder against their barcodes.

    Each file name in *test_image_pth* must start with an integer barcode
    followed by ``_``; ``int()`` raises ``ValueError`` otherwise. The
    de-duplicated barcodes and the features of all files are passed to
    ``search_in_index(..., temp_index=False)``.

    :param test_image_pth: directory containing the query images.
    :param index_name: index passed through to ``search_in_index``.
    :return: whatever ``search_in_index`` returns. The original version
        computed this value and then discarded it.
    """
    # List the directory once so barcodes and image paths describe the same
    # snapshot (the original listed it twice).
    names = os.listdir(test_image_pth)
    barcodes = list({int(os.path.basename(name).split('_')[0]) for name in names})
    img_lists = [os.sep.join([test_image_pth, name]) for name in names]
    feature_lists = get_feature_list(img_lists, False)
    result = search_in_index(barcode=barcodes,
                             query=feature_lists,
                             index_name=index_name,
                             temp_index=False)
    # NOTE(review): this prints the query features, not the search result.
    # It is kept for compatibility; it may have been meant as print(result).
    print(feature_lists)
    return result
|
|
|
|
|
|
if __name__ == '__main__':
    # Build the feature library from the sample data set.
    pth = 'imageQualityData/test_data'
    save_index_name = 'search_library/test_index_10_simple_0717.usearch'
    create_feature_library(pth, save_index_name=save_index_name)

    # To query the library afterwards (the example paths are from the
    # original developer machine; adjust before use):
    #   test_images_pth = 'D:/Project/ieemoo/image_quality_assessment/imageQualityData/test_images'
    #   index_name = 'D:/Project/ieemoo/image_quality_assessment/search_library/test_index_10_simple_0717.usearch'
    #   search_top_in_index(test_images_pth, index_name)   # 1:N
    #   search_one_in_index(test_images_pth, index_name)   # 1:1