# -*- coding: utf-8 -*-
"""
Created on Wed Nov 20 11:17:29 2024

@author: ym
"""
|
||
|
||
import os
|
||
import cv2
|
||
import pickle
|
||
import numpy as np
|
||
from scipy.spatial.distance import cdist
|
||
import matplotlib.pyplot as plt
|
||
|
||
def save_imgpairs(barcode, imgpaths, matrix, savepath, thresh=(0.4, 0.6), ctype="intra"):
    """Save side-by-side pictures of the sample pairs picked by a similarity threshold.

    Args:
        barcode: Name of the sub-directory created under ``savepath``; it is
            also embedded in every output file name.
        imgpaths: ``[paths]`` when the rows and columns of ``matrix`` index
            the same list of image paths, or ``[row_paths, col_paths]`` when
            rows index the first list and columns the second.
        matrix: 2-D similarity matrix.
        savepath: Root output directory.
        thresh: ``(low, high)`` pair. With ``ctype == "intra"`` the pairs
            above the diagonal whose similarity is ``< high`` are saved; with
            any other ``ctype`` every pair whose similarity is ``> low`` is
            saved.
        ctype: ``"intra"`` selects the upper-triangle rule; any other value
            selects the whole-matrix rule.

    Raises:
        ValueError: If at least one pair was selected but ``imgpaths`` does
            not hold exactly one or two path lists.
    """
    if ctype == "intra":
        # k=1 excludes the diagonal (a sample compared with itself).
        rows, cols = np.triu_indices(matrix.shape[0], k=1)
        keep = matrix[rows, cols] < thresh[1]
        pairs = list(zip(rows[keep], cols[keep]))
    else:
        rows, cols = np.where(matrix > thresh[0])
        pairs = list(zip(rows, cols))

    if not pairs:
        return

    # Validate before touching the file system so a bad call leaves no
    # half-created output directory behind.
    if len(imgpaths) == 1:
        row_paths = col_paths = imgpaths[0]
    elif len(imgpaths) == 2:
        row_paths, col_paths = imgpaths
    else:
        raise ValueError(
            f"imgpaths must contain 1 or 2 path lists, got {len(imgpaths)}"
        )

    savepath = os.path.join(savepath, barcode)
    os.makedirs(savepath, exist_ok=True)

    for idx1, idx2 in pairs:
        img1 = cv2.imread(row_paths[idx1])
        img2 = cv2.imread(col_paths[idx2])
        if img1 is None or img2 is None:
            # cv2.imread returns None instead of raising when a file is
            # missing or cannot be decoded.
            print(f"skip pair ({idx1}, {idx2}) of {barcode}: image could not be read")
            continue

        simi = matrix[idx1, idx2]

        h1, w1 = img1.shape[:2]
        h2, w2 = img2.shape[:2]
        height, width = max(h1, h2), max(w1, w2)

        # Build the background directly as uint8: multiplying a uint8 array
        # by a default-int array would promote the canvas to a wider integer
        # type that the OpenCV drawing/writing calls do not expect.
        canvas = np.empty((height, 2 * width, 3), dtype=np.uint8)
        canvas[:] = (255, 128, 128)

        # First image in the top-left corner, second one flush right.
        canvas[:h1, :w1, :] = img1
        canvas[:h2, 2 * width - w2:, :] = img2

        text = f"sim: {simi:.2f}"
        org = (10, height - 10)
        cv2.putText(canvas, text, org, fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.75,
                    color=(0, 0, 255), thickness=2, lineType=cv2.LINE_AA)
        imgpath = os.path.join(savepath, f"{simi:.2f}_{barcode}_{idx1}_{idx2}.png")
        cv2.imwrite(imgpath, canvas)
|
||
|
||
|
||
|
||
def feat_analysis(featpath, savepath=r"D:\exhibition\result\stdfeat",
                  thresh=(0.4, 0.6), save_pairs=False):
    """Plot intra-class and inter-class similarity histograms of a standard feature set.

    Every ``*.pickle`` file directly under ``featpath`` is loaded as one
    class. Each file must unpickle to a mapping with the keys
    ``"feats_ft32"`` (passed to ``scipy.spatial.distance.cdist``, so a 2-D
    array with one row per sample), ``"barcode"`` and, only when
    ``save_pairs`` is true, ``"imgpaths"``.

    Args:
        featpath: Directory that holds the pickled feature files.
        savepath: Root directory handed to ``save_imgpairs``; used only when
            ``save_pairs`` is true.
        thresh: ``(low, high)`` thresholds handed to ``save_imgpairs``; used
            only when ``save_pairs`` is true.
        save_pairs: When true, also write the image pairs selected by
            ``save_imgpairs`` for every class and every pair of classes.

    Returns:
        None. Two histograms are drawn on a new matplotlib figure.
    """
    # NOTE(security): pickle.load can execute arbitrary code; only point this
    # function at feature files from a trusted source.
    records = []
    for filename in os.listdir(featpath):
        if os.path.splitext(filename)[1] != ".pickle":
            continue
        with open(os.path.join(featpath, filename), 'rb') as f:
            records.append(pickle.load(f))

    if not records:
        print(f"No .pickle feature files found in {featpath}")
        return

    features = [record["feats_ft32"] for record in records]
    barcodes = [record["barcode"] for record in records]

    # Similarities are gathered in lists and joined once at the end; growing
    # an array with np.concatenate inside the loop copies it every time.
    intra_parts = []
    for record, barcode, feats in zip(records, barcodes, features):
        matrix = 1 - cdist(feats, feats, 'cosine')

        # Upper triangle without the diagonal: each unordered pair once.
        rows, cols = np.triu_indices(matrix.shape[0], k=1)
        intra_parts.append(matrix[rows, cols])

        if save_pairs:
            save_imgpairs(barcode, [record["imgpaths"]], matrix, savepath, thresh, "intra")
        print(f"{barcode} have done!")

    inter_parts = []
    count = len(records)
    for i in range(count):
        for j in range(i + 1, count):
            matrix = 1 - cdist(features[i], features[j], 'cosine')
            inter_parts.append(matrix.ravel())

            bcd_ij = barcodes[i] + '_' + barcodes[j]
            if save_pairs:
                save_imgpairs(bcd_ij, [records[i]["imgpaths"], records[j]["imgpaths"]],
                              matrix, savepath, thresh, "inter")
            print(f"{bcd_ij} have done!")

    intra_simi = np.concatenate(intra_parts)
    # With a single class there are no class pairs to compare.
    inter_simi = np.concatenate(inter_parts) if inter_parts else np.empty(0)

    fig, axs = plt.subplots(2, 1)
    axs[0].hist(intra_simi, bins=100, color='blue', edgecolor='black', alpha=0.7)
    axs[0].set_xlim(0, 1)
    axs[0].set_xlabel('Performance')
    axs[0].set_title("intra similarity")

    axs[1].hist(inter_simi, bins=100, color='green', edgecolor='black', alpha=0.7)
    axs[1].set_xlim(0, 1)
    axs[1].set_xlabel('Performance')
    axs[1].set_title("inter similarity")

    print("Done!")
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
def main():
    """Run the similarity analysis on the standard feature directory."""
    feat_analysis(r"D:\exhibition\dataset\feats")
|
||
|
||
|
||
# Run the analysis only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|