Files
ieemoo-ai-review/detecttracking/contrast/seqfeat_compare.py
2025-01-22 13:16:44 +08:00

200 lines
5.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
Created on Fri Aug 9 10:36:45 2024
分析图像对间的相似度
@author: ym
"""
import os
import cv2
import numpy as np
import torch
import sys
from scipy.spatial.distance import cdist
''' 加载 LC 定义的模型形式'''
from config import config as conf
from model import resnet18 as resnet18
from test_ori import inference_image
# ============ Load the resnet18 backbone that silimarity_compare feeds to
# inference_image for feature extraction.
model = resnet18().to(conf.device)
# Multi-GPU variant, kept for reference:
# model = nn.DataParallel(model).to(conf.device)
model.load_state_dict(
    torch.load(conf.test_model, map_location=conf.device)
)
model.eval()  # evaluation mode (affects layers such as dropout / batch-norm)
print(f'load model {conf.testbackbone} ')

# File extensions used to decide which files count as input images.
IMG_FORMAT = ['.bmp', '.jpg', '.JPG', '.jpeg', '.png']
# =============================================================================
# ''' 加载REID中定义的模型形式'''
# sys.path.append(r"D:\DetectTracking")
# from tracking.trackers.reid.reid_interface import ReIDInterface
# from tracking.trackers.reid.config import config as ReIDConfig
# ReIDEncoder = ReIDInterface(ReIDConfig)
#
# def inference_image_ReID(images):
# batch_patches = []
# patches = []
# for d, img1 in enumerate(images):
#
#
# img = img1[:, :, ::-1].copy() # the model expects RGB inputs
# patch = ReIDEncoder.transform(img)
#
# # patch = patch.to(device=self.device).half()
# if str(ReIDEncoder.device) != "cpu":
# patch = patch.to(device=ReIDEncoder.device).half()
# else:
# patch = patch.to(device=ReIDEncoder.device)
#
# patches.append(patch)
# if (d + 1) % ReIDEncoder.batch_size == 0:
# patches = torch.stack(patches, dim=0)
# batch_patches.append(patches)
# patches = []
#
# if len(patches):
# patches = torch.stack(patches, dim=0)
# batch_patches.append(patches)
#
# features = np.zeros((0, ReIDEncoder.embedding_size))
# for patches in batch_patches:
# pred = ReIDEncoder.model(patches)
# pred[torch.isinf(pred)] = 1.0
# feat = pred.cpu().data.numpy()
# features = np.vstack((features, feat))
#
# return features
# =============================================================================
def silimarity_compare(imgpaths=r"D:\DetectTracking\contrast\images\2"):
    """Compute the pairwise cosine similarity of every image under a folder.

    Images are collected recursively from ``imgpaths``. A file qualifies when
    its lower-cased extension is in ``IMG_FORMAT``. Each image is embedded
    with the module-level ``model`` via ``inference_image``, and the
    embeddings are compared all-against-all.

    Args:
        imgpaths: Root folder to scan. It defaults to the previously
            hard-coded sample folder, so no-argument calls keep working.

    Returns:
        An ``(N, N)`` numpy array. Entry ``[i, j]`` is
        ``1 - max(0, cosine_distance(feature_i, feature_j))``. Images are
        ordered by a sorted directory walk. An empty ``(0, 0)`` array is
        returned when no image is found.
    """
    filepaths = []
    for root, dirs, filenames in os.walk(imgpaths):
        dirs.sort()  # visit sub-folders in a deterministic order
        for filename in sorted(filenames):
            ext = os.path.splitext(filename)[1]
            # IMG_FORMAT contains every lower-case variant, so lowering the
            # extension also accepts e.g. '.PNG', '.BMP', '.JPEG'.
            if ext.lower() in IMG_FORMAT:
                filepaths.append(os.path.join(root, filename))

    if not filepaths:
        # Nothing to embed: avoid calling the model with an empty batch.
        print(f"No images found under {imgpaths}")
        return np.zeros((0, 0))

    feature = np.asarray(
        inference_image(filepaths, conf.test_transform, model, conf.device)
    )
    # Cosine distance does not depend on the scale of each row, so no
    # explicit L2 normalisation is needed before cdist.
    similar = 1 - np.maximum(0.0, cdist(feature, feature, metric='cosine'))
    print("Done!")
    return similar
def _natural_sort_key(path):
    """Sort key that orders embedded integers numerically ("2" before "10")."""
    import re
    name = os.path.basename(path)
    parts = re.split(r'(\d+)', name)
    # re.split with a capturing group alternates text / digit runs, so odd
    # indices are always digit runs and keys compare type-consistently.
    key = [int(p) if i % 2 else p.lower() for i, p in enumerate(parts)]
    return key, name  # the raw name breaks ties such as "a01" vs "a1"


def _pair_canvas(imga, imgb, similarity):
    """Place two images side by side and draw ``similarity`` near the seam."""
    ha, wa = imga.shape[:2]
    hb, wb = imgb.shape[:2]
    h, w = max(ha, hb), max(wa, wb)
    img = np.zeros((h, 2 * w, 3), np.uint8)
    img[0:ha, 0:wa] = imga
    img[0:hb, w:(w + wb)] = imgb
    linewidth = max(round((h + 2 * w) / 2 * 0.001), 2)
    cv2.putText(img,
                text=f'{similarity:.2f}',           # text string to draw
                org=(max(w - 20, 10), h - 10),      # bottom-left corner of the text
                fontFace=0,                         # font type
                fontScale=linewidth / 3,            # font scale factor
                color=(0, 0, 255),                  # text color (BGR red)
                thickness=linewidth,                # stroke thickness
                lineType=cv2.LINE_AA,               # anti-aliased lines
                )
    return img


def similarity_compare_sequence(root_dir):
    '''
    Compare the similarity of sub-images from adjacent frames.

    Images inside any folder under ``root_dir`` whose name contains
    "subimgs" are treated as sub-image crops. Files are ordered by a natural
    sort of their names; this assumes the frame order is encoded in the file
    names (TODO confirm). For each consecutive pair, a side-by-side image
    annotated with ``1 - max(0, cosine_distance)`` is written into the same
    folder as ``s<nameA>-vs-<nameB>.png``.

    Features come from the module-level ``model`` via ``inference_image``.
    The ReID encoder (``inference_image_ReID``) that this function used to
    call is commented out in this module, so calling it raised NameError.
    '''
    extensions = ['.png', '.jpg']
    for dirpath, _dirnames, filenames in os.walk(root_dir):
        # Only folders whose own name contains "subimgs" hold sub-images.
        if 'subimgs' not in os.path.basename(dirpath):
            continue

        filepaths = []
        for filename in filenames:
            stem, ext = os.path.splitext(filename)
            if ext.lower() not in extensions:
                continue
            # Skip comparison images written by a previous run.
            if stem.startswith('s') and '-vs-' in stem:
                continue
            filepaths.append(os.path.join(dirpath, filename))
        if len(filepaths) < 2:
            continue  # need at least one adjacent pair
        filepaths.sort(key=_natural_sort_key)

        # Embed every image once instead of twice (once per neighbouring pair).
        feats = np.asarray(
            inference_image(filepaths, conf.test_transform, model, conf.device)
        )

        prev_path = filepaths[0]
        prev_img = cv2.imread(prev_path)
        for i in range(1, len(filepaths)):
            cur_path = filepaths[i]
            cur_img = cv2.imread(cur_path)  # returns None if unreadable
            if prev_img is None or cur_img is None:
                print(f'Skip unreadable pair: {prev_path} / {cur_path}')
            else:
                dist = cdist(feats[i - 1:i], feats[i:i + 1], metric='cosine')[0, 0]
                similar = 1 - np.maximum(0.0, dist)
                canvas = _pair_canvas(prev_img, cur_img, similar)
                fnma = os.path.splitext(os.path.basename(prev_path))[0]
                fnmb = os.path.splitext(os.path.basename(cur_path))[0]
                spath = os.path.join(dirpath, 's' + fnma + '-vs-' + fnmb + '.png')
                cv2.imwrite(spath, canvas)
            prev_path, prev_img = cur_path, cur_img
    return
def main(root_dir=r"D:\contrast\dataset\result\20240723-112242_6923790709882"):
    """Run the adjacent-frame sub-image comparison on one result folder.

    Args:
        root_dir: Folder searched recursively for "subimgs" directories.
            It defaults to the previously hard-coded dataset path.

    Exceptions are reported rather than propagated, matching the original
    best-effort behaviour. The full traceback is now printed as well, so
    failures (e.g. a missing helper name) can be located.
    """
    import traceback
    try:
        similarity_compare_sequence(root_dir)
    except Exception as e:
        print(f'Error: {e}')
        traceback.print_exc()
if __name__ == '__main__':
    # Script entry point. The all-pairs folder comparison is active; point
    # `_run` at `main` instead to run the adjacent-frame sequence comparison.
    _run = silimarity_compare
    _run()