Files
ieemoo-ai-review/detecttracking/contrast/select_subimgs.py
2025-01-22 13:16:44 +08:00

109 lines
3.0 KiB
Python

# -*- coding: utf-8 -*-
"""
Created on Mon Dec 23 13:58:13 2024
Written to select standard sub-images for Wuhuaqi
@author: ym
"""
import os
import time
# import torch
import pickle
# import json
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from scipy.spatial.distance import cdist
from feat_extract.config import config as conf
# from model import resnet18 as resnet18
from feat_extract.inference import FeatsInterface #, inference_image
IMG_FORMAT = ['.bmp', '.jpg', '.jpeg', '.png']


def gen_features(imgpath):
    """
    Extract L2-normalised features for every image directly inside a folder.

    Args:
        imgpath (str): directory to scan (non-recursive). Files whose
            extension is in IMG_FORMAT (compared case-insensitively) are loaded.
    Returns:
        (features, imgnames): `features` is the output of
        `FeatsInterface.inference` with each row divided by its L2 norm;
        `imgnames[k]` is the filename (without extension) of the image that
        produced row k.
    """
    encoder = FeatsInterface(conf)

    imgs, imgnames = [], []
    # os.listdir order is arbitrary; sorting makes the row order reproducible.
    for filename in sorted(os.listdir(imgpath)):
        file, ext = os.path.splitext(filename)
        # Case-insensitive so files such as "xxx.JPG" are not silently skipped.
        if ext.lower() not in IMG_FORMAT:
            continue
        fpath = os.path.join(imgpath, filename)
        # Image.open is lazy and keeps the file handle open; copy() loads the
        # pixel data so the handle can be closed right away instead of leaking
        # one descriptor per image.
        with Image.open(fpath) as img:
            imgs.append(img.copy())
        # The full stem is used as the name. The previous shortened-name
        # variant indexed filelist[1], [2] and [-3] and raised IndexError on
        # names with fewer than three '_'-separated parts, even though its
        # result was never used.
        imgnames.append(file)

    features = encoder.inference(imgs)
    # Row-wise L2 normalisation, so dot products equal cosine similarity.
    features /= np.linalg.norm(features, axis=1)[:, None]
    return features, imgnames
def top_p_percent_indices(matrix, p):
    """
    Find the indices of the top p% largest elements of an array.

    Args:
        matrix (np.ndarray): array of values, typically 2D. Any array-like
            accepted by np.asarray works.
        p: percentage, normally 0-100. The number of elements requested is
            floor(size * p / 100), clamped to the range [1, size].
    Returns:
        List[Tuple[int, ...]]: one index tuple per selected element
        ((row, column) for a 2D input), in row-major order. Every element
        equal to the threshold value is included, so ties can make the list
        longer than the requested count. An empty input returns [].
    """
    matrix = np.asarray(matrix)
    flat_matrix = matrix.ravel()
    num_elements = flat_matrix.size
    if num_elements == 0:
        return []

    # Integer arithmetic avoids float truncation errors: with the old
    # int(num_elements * 0.01 * p), 29 elements at p=100 gave 28, not 29.
    threshold_index = int(num_elements * p // 100)
    # Keep at least one element and never exceed the array size, which
    # would put np.partition's kth out of bounds.
    threshold_index = min(max(1, threshold_index), num_elements)
    threshold_value = np.partition(flat_matrix, -threshold_index)[-threshold_index]

    # Select every element >= threshold and return their index tuples.
    indices = np.argwhere(matrix >= threshold_value)
    return list(map(tuple, indices))
def _offdiag_col_to_image(row, col):
    """
    Map a column of the diagonal-free (n, n-1) similarity matrix to an image index.

    Row `row` of that matrix skips the diagonal entry, so every column at or
    after `row` is shifted left by one relative to the full (n, n) matrix.
    """
    return int(col) + int(col >= row)


def main(imgpath=r"\\192.168.1.28\share\数据\已完成数据\展厅数据\v1.0\比对数据\整理\zhantingBase\6923555210479"):
    """
    Analyse the pairwise cosine similarity of the images in one folder.

    Extracts features with gen_features(), builds the image-to-image cosine
    similarity matrix, prints the most/least similar pairs and the image with
    the lowest mean similarity, and shows a histogram of all pairwise
    similarities.

    Args:
        imgpath (str): folder of sub-images. Defaults to the previously
            hard-coded share path.
    Returns:
        dict of the computed statistics, or None when fewer than two images
        are found. All pair indices are positions in the returned "imgnames"
        list (not columns of the diagonal-free matrix).
    """
    feats, imgnames = gen_features(imgpath)
    n = len(feats)
    if n < 2:
        # The statistics below need at least one pair of distinct images.
        print("need at least 2 images, found %d in %s" % (n, imgpath))
        return None

    matrix = 1 - cdist(feats, feats, 'cosine')

    # Drop the diagonal (self-similarity). Row i then holds image i's
    # similarity to every other image, in the original column order.
    nmatrix = matrix[~np.eye(n, dtype=bool)].reshape(n, n - 1)

    def to_pairs(row_col_pairs):
        # Translate (row, nmatrix column) pairs into (image, image) pairs.
        return [(int(i), _offdiag_col_to_image(i, j)) for i, j in row_col_pairs]

    top_p_large_pairs = to_pairs(top_p_percent_indices(nmatrix, 1))
    top_p_small_pairs = to_pairs(top_p_percent_indices(-1 * nmatrix, 1))
    simi_mean = np.mean(nmatrix, axis=1)

    max_simi = np.max(nmatrix)
    max_pairs = to_pairs(zip(*np.where(nmatrix == max_simi)))
    min_simi = np.min(nmatrix)
    min_pairs = to_pairs(zip(*np.where(nmatrix == min_simi)))

    print("max similarity %.4f:" % max_simi,
          [(imgnames[i], imgnames[j]) for i, j in max_pairs])
    print("min similarity %.4f:" % min_simi,
          [(imgnames[i], imgnames[j]) for i, j in min_pairs])
    least_typical = int(np.argmin(simi_mean))
    print("lowest mean similarity %.4f: %s"
          % (simi_mean[least_typical], imgnames[least_typical]))

    # Each unordered pair counted once: upper triangle, row < column.
    simils = matrix[np.triu_indices(n, k=1)]
    fig, ax = plt.subplots()
    ax.hist(simils, bins=60, range=(-0.2, 1), edgecolor='black')
    ax.set_xlim([-0.2, 1])
    ax.set_title("Similarity")
    # Without show() the figure is never displayed when run as a script.
    plt.show()

    print("done!")
    return {
        "imgnames": imgnames,
        "matrix": matrix,
        "simi_mean": simi_mean,
        "top_p_large_pairs": top_p_large_pairs,
        "top_p_small_pairs": top_p_small_pairs,
        "max_simi": max_simi,
        "max_pairs": max_pairs,
        "min_simi": min_simi,
        "min_pairs": min_pairs,
    }


if __name__ == '__main__':
    main()