更新 detacttracking
This commit is contained in:
200
detecttracking/contrast/seqfeat_compare.py
Normal file
200
detecttracking/contrast/seqfeat_compare.py
Normal file
@ -0,0 +1,200 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Created on Fri Aug 9 10:36:45 2024
|
||||
分析图像对间的相似度
|
||||
|
||||
@author: ym
|
||||
"""
|
||||
import os
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
import sys
|
||||
from scipy.spatial.distance import cdist
|
||||
|
||||
''' 加载 LC 定义的模型形式'''
|
||||
from config import config as conf
|
||||
from model import resnet18 as resnet18
|
||||
from test_ori import inference_image
|
||||
##============ load resnet mdoel
|
||||
model = resnet18().to(conf.device)
|
||||
# model = nn.DataParallel(model).to(conf.device)
|
||||
model.load_state_dict(torch.load(conf.test_model, map_location=conf.device))
|
||||
model.eval()
|
||||
print('load model {} '.format(conf.testbackbone))
|
||||
|
||||
IMG_FORMAT = ['.bmp', '.jpg', '.JPG', '.jpeg', '.png']
|
||||
|
||||
# =============================================================================
|
||||
# ''' 加载REID中定义的模型形式'''
|
||||
# sys.path.append(r"D:\DetectTracking")
|
||||
# from tracking.trackers.reid.reid_interface import ReIDInterface
|
||||
# from tracking.trackers.reid.config import config as ReIDConfig
|
||||
# ReIDEncoder = ReIDInterface(ReIDConfig)
|
||||
#
|
||||
# def inference_image_ReID(images):
|
||||
# batch_patches = []
|
||||
# patches = []
|
||||
# for d, img1 in enumerate(images):
|
||||
#
|
||||
#
|
||||
# img = img1[:, :, ::-1].copy() # the model expects RGB inputs
|
||||
# patch = ReIDEncoder.transform(img)
|
||||
#
|
||||
# # patch = patch.to(device=self.device).half()
|
||||
# if str(ReIDEncoder.device) != "cpu":
|
||||
# patch = patch.to(device=ReIDEncoder.device).half()
|
||||
# else:
|
||||
# patch = patch.to(device=ReIDEncoder.device)
|
||||
#
|
||||
# patches.append(patch)
|
||||
# if (d + 1) % ReIDEncoder.batch_size == 0:
|
||||
# patches = torch.stack(patches, dim=0)
|
||||
# batch_patches.append(patches)
|
||||
# patches = []
|
||||
#
|
||||
# if len(patches):
|
||||
# patches = torch.stack(patches, dim=0)
|
||||
# batch_patches.append(patches)
|
||||
#
|
||||
# features = np.zeros((0, ReIDEncoder.embedding_size))
|
||||
# for patches in batch_patches:
|
||||
# pred = ReIDEncoder.model(patches)
|
||||
# pred[torch.isinf(pred)] = 1.0
|
||||
# feat = pred.cpu().data.numpy()
|
||||
# features = np.vstack((features, feat))
|
||||
#
|
||||
# return features
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def silimarity_compare():
|
||||
|
||||
imgpaths = r"D:\DetectTracking\contrast\images\2"
|
||||
|
||||
|
||||
filepaths = []
|
||||
for root, dirs, filenames in os.walk(imgpaths):
|
||||
for filename in filenames:
|
||||
file, ext = os.path.splitext(filename)
|
||||
if ext not in IMG_FORMAT: continue
|
||||
|
||||
file_path = os.path.join(root, filename)
|
||||
filepaths.append(file_path)
|
||||
|
||||
feature = inference_image(filepaths, conf.test_transform, model, conf.device)
|
||||
feature /= np.linalg.norm(feature, axis=1)[:, None]
|
||||
|
||||
similar = 1 - np.maximum(0.0, cdist(feature, feature, metric='cosine'))
|
||||
|
||||
|
||||
print("Done!")
|
||||
|
||||
|
||||
|
||||
def similarity_compare_sequence(root_dir):
|
||||
'''
|
||||
root_dir:包含 "subimgs"字段的文件夹中图像为 subimg子图
|
||||
功能:相邻帧子图间相似度比较
|
||||
|
||||
|
||||
'''
|
||||
|
||||
all_files = []
|
||||
extensions = ['.png', '.jpg']
|
||||
for dirpath, dirnames, filenames in os.walk(root_dir):
|
||||
filepaths = []
|
||||
for filename in filenames:
|
||||
if os.path.basename(dirpath).find('subimgs') < 0:
|
||||
continue
|
||||
file, ext = os.path.splitext(filename)
|
||||
if ext in extensions:
|
||||
imgpath = os.path.join(dirpath, filename)
|
||||
filepaths.append(imgpath)
|
||||
nf = len(filepaths)
|
||||
if nf==0:
|
||||
continue
|
||||
|
||||
fnma = os.path.basename(filepaths[0]).split('.')[0]
|
||||
imga = cv2.imread(filepaths[0])
|
||||
ha, wa = imga.shape[:2]
|
||||
|
||||
for i in range(1, nf):
|
||||
fnmb = os.path.basename(filepaths[i]).split('.')[0]
|
||||
|
||||
imgb = cv2.imread(filepaths[i])
|
||||
hb, wb = imgb.shape[:2]
|
||||
|
||||
|
||||
feats = inference_image_ReID(((imga, imgb)))
|
||||
|
||||
similar = 1 - np.maximum(0.0, cdist(feats, feats, metric='cosine'))
|
||||
|
||||
|
||||
h, w = max((ha, hb)), max((wa, wb))
|
||||
img = np.zeros(((h, 2*w, 3)), np.uint8)
|
||||
img[0:ha, 0:wa], img[0:hb, w:(w+wb)] = imga, imgb
|
||||
|
||||
linewidth = max(round(((h+2*w))/2 * 0.001), 2)
|
||||
cv2.putText(img,
|
||||
text=f'{similar[0,1]:.2f}', # Text string to be drawn
|
||||
org=(max(w-20, 10), h-10), # Bottom-left corner of the text string
|
||||
fontFace=0, # Font type
|
||||
fontScale=linewidth/3, # Font scale factor
|
||||
color=(0, 0, 255), # Text color
|
||||
thickness=linewidth, # Thickness of the lines used to draw a text
|
||||
lineType=cv2.LINE_AA, # Line type
|
||||
)
|
||||
spath = os.path.join(dirpath, 's'+fnma+'-vs-'+fnmb+'.png')
|
||||
cv2.imwrite(spath, img)
|
||||
|
||||
|
||||
fnma = os.path.basename(filepaths[i]).split('.')[0]
|
||||
imga = imgb.copy()
|
||||
ha, wa = imga.shape[:2]
|
||||
|
||||
|
||||
return
|
||||
|
||||
|
||||
def main():
|
||||
root_dir = r"D:\contrast\dataset\result\20240723-112242_6923790709882"
|
||||
|
||||
try:
|
||||
similarity_compare_sequence(root_dir)
|
||||
except Exception as e:
|
||||
print(f'Error: {e}')
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# main()
|
||||
|
||||
silimarity_compare()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user