# -*- coding: utf-8 -*-
"""
Created on Fri Aug 9 10:36:45 2024

Analyze the similarity between pairs of images.

@author: ym
"""
import os
import re
import sys
import traceback

import cv2
import numpy as np
import torch
from scipy.spatial.distance import cdist

# Load the model in the form defined by LC.
from config import config as conf
from model import resnet18 as resnet18
from test_ori import inference_image

##============ load resnet model
# NOTE: this runs at import time, so importing this module loads the checkpoint.
model = resnet18().to(conf.device)
# model = nn.DataParallel(model).to(conf.device)
model.load_state_dict(torch.load(conf.test_model, map_location=conf.device))
model.eval()
print('load model {} '.format(conf.testbackbone))

IMG_FORMAT = ['.bmp', '.jpg', '.JPG', '.jpeg', '.png']

# =============================================================================
# ''' Load the model in the form defined by REID '''
# (Kept for reference: similarity_compare_sequence used to call
#  inference_image_ReID, which only exists in this commented-out code.)
# sys.path.append(r"D:\DetectTracking")
# from tracking.trackers.reid.reid_interface import ReIDInterface
# from tracking.trackers.reid.config import config as ReIDConfig
# ReIDEncoder = ReIDInterface(ReIDConfig)
#
# def inference_image_ReID(images):
#     batch_patches = []
#     patches = []
#     for d, img1 in enumerate(images):
#
#         img = img1[:, :, ::-1].copy()  # the model expects RGB inputs
#         patch = ReIDEncoder.transform(img)
#
#         # patch = patch.to(device=self.device).half()
#         if str(ReIDEncoder.device) != "cpu":
#             patch = patch.to(device=ReIDEncoder.device).half()
#         else:
#             patch = patch.to(device=ReIDEncoder.device)
#
#         patches.append(patch)
#         if (d + 1) % ReIDEncoder.batch_size == 0:
#             patches = torch.stack(patches, dim=0)
#             batch_patches.append(patches)
#             patches = []
#
#     if len(patches):
#         patches = torch.stack(patches, dim=0)
#         batch_patches.append(patches)
#
#     features = np.zeros((0, ReIDEncoder.embedding_size))
#     for patches in batch_patches:
#         pred = ReIDEncoder.model(patches)
#         pred[torch.isinf(pred)] = 1.0
#         feat = pred.cpu().data.numpy()
#         features = np.vstack((features, feat))
#
#     return features
# =============================================================================


def _natural_key(name):
    """Sort key that orders embedded integers numerically ("2" < "10")."""
    # re.split with a capturing group alternates str / digit-run, so the
    # key lists always have matching types position by position.
    return [int(tok) if tok.isdigit() else tok.lower()
            for tok in re.split(r'(\d+)', name)]


def _is_pair_output(filename):
    """True for images this script itself writes (s<a>-vs-<b>.png)."""
    return filename.startswith('s') and '-vs-' in filename


def _pair_similarity(path_a, path_b):
    """Return the cosine similarity of the model embeddings of two image files."""
    feats = inference_image([path_a, path_b], conf.test_transform, model, conf.device)
    # Clip tiny negative distances caused by rounding so similarity <= 1.
    similar = 1 - np.maximum(0.0, cdist(feats, feats, metric='cosine'))
    return similar[0, 1]


def _side_by_side(img_a, img_b, score):
    """Place two images side by side on a black canvas and draw *score* in red.

    Both halves are ``max(width)`` wide and the canvas is ``max(height)`` tall;
    images are anchored at the top-left of their half.
    """
    ha, wa = img_a.shape[:2]
    hb, wb = img_b.shape[:2]
    h, w = max(ha, hb), max(wa, wb)

    canvas = np.zeros((h, 2 * w, 3), np.uint8)
    canvas[0:ha, 0:wa] = img_a
    canvas[0:hb, w:w + wb] = img_b

    # Text size and stroke scale with the canvas size, with a floor of 2.
    linewidth = max(round((h + 2 * w) / 2 * 0.001), 2)
    cv2.putText(canvas,
                text=f'{score:.2f}',
                org=(max(w - 20, 10), h - 10),  # bottom-left corner of the text
                fontFace=cv2.FONT_HERSHEY_SIMPLEX,
                fontScale=linewidth / 3,
                color=(0, 0, 255),  # BGR red
                thickness=linewidth,
                lineType=cv2.LINE_AA,
                )
    return canvas


def _compare_consecutive(dirpath, filepaths):
    """Write one annotated side-by-side image per consecutive pair of *filepaths*.

    Unreadable files are skipped with a warning, so the pair bridges over them.
    Outputs go to *dirpath* as ``s<stemA>-vs-<stemB>.png``.
    """
    prev_path, prev_img = None, None
    for path in filepaths:
        img = cv2.imread(path)
        if img is None:  # cv2.imread returns None instead of raising
            print(f'Warning: cannot read {path}, skipped')
            continue
        if prev_img is not None:
            score = _pair_similarity(prev_path, path)
            canvas = _side_by_side(prev_img, img, score)
            stem_a = os.path.splitext(os.path.basename(prev_path))[0]
            stem_b = os.path.splitext(os.path.basename(path))[0]
            spath = os.path.join(dirpath, 's' + stem_a + '-vs-' + stem_b + '.png')
            cv2.imwrite(spath, canvas)
        prev_path, prev_img = path, img


def silimarity_compare(imgpaths=r"D:\DetectTracking\contrast\images\2"):
    """Compute the pairwise cosine similarity between all images under a folder.

    Args:
        imgpaths: Root folder, searched recursively for files whose extension
            is in IMG_FORMAT (compared case-insensitively).

    Returns:
        (filepaths, similar): the image paths (sorted per directory) and an
        (N, N) array of ``1 - cosine distance`` whose rows and columns follow
        ``filepaths``. When no image is found, ``similar`` has shape (0, 0).
    """
    wanted = {ext.lower() for ext in IMG_FORMAT}
    filepaths = []
    for root, dirs, filenames in os.walk(imgpaths):
        dirs.sort()  # deterministic traversal order
        for filename in sorted(filenames, key=_natural_key):
            if os.path.splitext(filename)[1].lower() in wanted:
                filepaths.append(os.path.join(root, filename))

    if not filepaths:
        print(f'No images found in {imgpaths}')
        return filepaths, np.zeros((0, 0))

    feature = inference_image(filepaths, conf.test_transform, model, conf.device)
    # Cosine distance is scale-invariant, so no explicit L2 normalisation is
    # needed; clipping absorbs tiny negative distances from rounding.
    similar = 1 - np.maximum(0.0, cdist(feature, feature, metric='cosine'))

    print("Done!")
    return filepaths, similar


def similarity_compare_sequence(root_dir):
    """Compare each sub-image with the next one, frame by frame.

    For every directory under *root_dir* whose name contains ``subimgs``,
    the .png/.jpg images (case-insensitive) are ordered by file name, with
    embedded numbers compared numerically. Each image is then compared with
    the following one. For every pair, a side-by-side image annotated with
    their cosine similarity is written into that same directory as
    ``s<nameA>-vs-<nameB>.png``. Such output files are not treated as inputs
    on later runs.

    Features come from the module-level ``model`` via ``inference_image``.
    The earlier ReID path is disabled (see the commented block above).
    """
    extensions = {'.png', '.jpg'}
    for dirpath, dirnames, filenames in os.walk(root_dir):
        if 'subimgs' not in os.path.basename(dirpath):
            continue
        filepaths = [
            os.path.join(dirpath, f)
            for f in sorted(filenames, key=_natural_key)
            if os.path.splitext(f)[1].lower() in extensions and not _is_pair_output(f)
        ]
        if len(filepaths) < 2:
            continue
        _compare_consecutive(dirpath, filepaths)


def main():
    """Run the consecutive-frame comparison on a hard-coded result folder."""
    root_dir = r"D:\contrast\dataset\result\20240723-112242_6923790709882"
    try:
        similarity_compare_sequence(root_dir)
    except Exception as e:
        # Top-level script boundary: report the error and keep the traceback.
        print(f'Error: {e}')
        traceback.print_exc()


if __name__ == '__main__':
    # main()
    silimarity_compare()