373 lines
13 KiB
Python
373 lines
13 KiB
Python
# -*- coding: utf-8 -*-
|
||
import os.path as osp
|
||
from typing import Dict, List, Set, Tuple
|
||
|
||
import torch
|
||
import torch.nn as nn
|
||
import numpy as np
|
||
from PIL import Image
|
||
import json
|
||
import matplotlib.pyplot as plt
|
||
|
||
# from config import config as conf
|
||
from tools.dataset import get_transform
|
||
from tools.image_joint import merge_imgs
|
||
from tools.getHeatMap import cal_cam
|
||
from configs import trainer_tools
|
||
import yaml
|
||
from datetime import datetime
|
||
|
||
# Load the evaluation configuration once, at import time.
# The path is relative to the current working directory, so the script has to
# be started from a directory whose parent contains ``configs/test.yml``.
# The encoding is pinned to UTF-8 so the result does not depend on the
# platform's locale default.
# NOTE(review): FullLoader is kept as in the original; if this file can ever
# come from an untrusted source, yaml.safe_load is the conservative choice.
with open('../configs/test.yml', 'r', encoding='utf-8') as f:
    conf = yaml.load(f, Loader=yaml.FullLoader)

# Constants from config
embedding_size = conf["base"]["embedding_size"]
img_size = conf["transform"]["img_size"]
device = conf["base"]["device"]
|
||
|
||
|
||
def unique_image(pair_list: str) -> Set[str]:
    """Collect the distinct image names referenced by a pair-list file.

    Every non-blank line is expected to hold exactly three
    whitespace-separated fields: two image names and a third field that is
    ignored here. Lines with any other number of fields are reported on
    stdout and skipped.

    Args:
        pair_list: Path of the pair-list text file.

    Returns:
        Set with every image name found in the first two columns.

    Raises:
        IOError: If the file cannot be opened or read; the error is printed
            and then re-raised.
    """
    seen = set()
    try:
        with open(pair_list, 'r') as handle:
            for raw in handle:
                line = raw.strip()
                if not line:
                    continue
                fields = line.split()
                # Anything other than "img1 img2 label" is malformed.
                if len(fields) != 3:
                    print(f"Skipping malformed line: {line}")
                    continue
                seen.add(fields[0])
                seen.add(fields[1])
    except IOError as e:
        print(f"Error reading pair list file: {e}")
        raise

    return seen
|
||
|
||
|
||
def group_image(images: Set[str], batch_size: int) -> List[List[str]]:
    """
    Split image paths into consecutive batches.

    Args:
        images: Collection of image paths; it is copied into a list first
        batch_size: Maximum number of paths per batch

    Returns:
        List of batches in iteration order; the last batch may be shorter
    """
    ordered = list(images)
    # Slicing clamps at the end of the list, so the tail batch needs no
    # special handling.
    return [
        ordered[start:start + batch_size]
        for start in range(0, len(ordered), batch_size)
    ]
|
||
|
||
|
||
def _preprocess(images: list, transform) -> torch.Tensor:
    """Open each image, apply ``transform`` and stack the results.

    Args:
        images: Image paths (anything ``Image.open`` accepts).
        transform: Callable applied to every opened image; its outputs are
            handed to ``torch.stack``.

    Returns:
        The stacked transform outputs, with a new leading dimension that
        indexes the images in input order.
    """
    tensors = [transform(Image.open(path)) for path in images]
    return torch.stack(tensors)
|
||
|
||
|
||
def test_preprocess(images: list, transform) -> torch.Tensor:
    """Open each image, drop any alpha channel, transform and stack.

    Identical to ``_preprocess`` except that images whose mode is ``'RGBA'``
    are converted to ``'RGB'`` before ``transform`` is applied.

    Args:
        images: Image paths (anything ``Image.open`` accepts).
        transform: Callable applied to every opened image; its outputs are
            handed to ``torch.stack``.

    Returns:
        The stacked transform outputs, with a new leading dimension that
        indexes the images in input order.
    """
    def _load(path):
        picture = Image.open(path)
        # Only RGBA input is converted; every other mode is passed through.
        return picture.convert('RGB') if picture.mode == 'RGBA' else picture

    return torch.stack([transform(_load(path)) for path in images])
|
||
|
||
|
||
def featurize(
        images: List[str],
        transform: callable,
        net: nn.Module,
        device: torch.device,
        train: bool = False
) -> Dict[str, torch.Tensor]:
    """Run ``net`` on one batch of image paths and key the outputs by path.

    Args:
        images: Paths of the images that form the batch.
        transform: Per-image transform passed to the preprocessing function.
        net: Network to evaluate; it is moved to ``device`` on every call.
        device: Device on which both the batch and the network are placed.
        train: Selects ``_preprocess`` when true, ``test_preprocess``
            otherwise.

    Returns:
        Mapping from each path in ``images`` to the matching row of the
        network output.

    Raises:
        Exception: Any failure is printed and re-raised unchanged.
    """
    try:
        loader = _preprocess if train else test_preprocess
        batch = loader(images, transform).to(device)
        net = net.to(device)

        # Inference only: no autograd graph is recorded.
        with torch.no_grad():
            # This is a plain fp16 cast of the input driven by the config,
            # not autocast-style mixed precision.
            if conf['models']['half']:
                batch = batch.half()
            outputs = net(batch)

        # Pair every path with its output row.
        return dict(zip(images, outputs))
    except Exception as e:
        print(f"Error in feature extraction: {e}")
        raise
|
||
|
||
|
||
def cosin_metric(x1, x2):
    """Return ``dot(x1, x2)`` divided by the product of the two norms.

    This is the cosine similarity for vector inputs. No guard is applied for
    zero-norm inputs, so those follow NumPy's division-by-zero behaviour.
    """
    inner = np.dot(x1, x2)
    scale = np.linalg.norm(x1) * np.linalg.norm(x2)
    return inner / scale
|
||
|
||
|
||
def threshold_search(y_score, y_true):
    """Pick the score that gives the best accuracy when used as a threshold.

    Every value of ``y_score`` is tried in order as a threshold; a sample is
    predicted positive when its score is ``>=`` that threshold. Accuracy is
    the fraction of predictions equal to ``y_true``. On ties the earliest
    candidate wins, because only a strictly better accuracy replaces it.

    Args:
        y_score: Sequence of scores.
        y_true: Sequence of ground-truth labels aligned with ``y_score``.

    Returns:
        Tuple ``(best_acc, best_th)``; ``(0, 0)`` if no candidate achieves an
        accuracy above zero (for example on empty input).
    """
    scores = np.asarray(y_score)
    truth = np.asarray(y_true)
    best_acc, best_th = 0, 0
    for candidate in scores:
        predicted = scores >= candidate
        hit_rate = np.mean((predicted == truth).astype(int))
        if hit_rate > best_acc:
            best_acc, best_th = hit_rate, candidate
    return best_acc, best_th
|
||
|
||
|
||
def showgrid(recall, recall_TN, PrecisePos, PreciseNeg, Correct):
    """Plot the per-threshold metric curves, save them and show the figure.

    The five sequences must have the same length; they are drawn against
    thresholds spread evenly over [-1, 1]. The figure is written to
    ``grid.png`` in the current working directory, shown with ``plt.show()``
    and then closed.

    Legend formulas follow the way ``compute_accuracy_recall`` builds the
    series: precision uses the predicted counts in the denominator
    (TP+FP / TN+FN), recall uses the ground-truth counts (TP+FN / TN+FP).

    Args:
        recall: TP / (TP + FN) per threshold.
        recall_TN: TN / (TN + FP) per threshold.
        PrecisePos: TP / (TP + FP) per threshold.
        PreciseNeg: TN / (TN + FN) per threshold.
        Correct: (TP + TN) / total per threshold.
    """
    # One x position per sample so the axis always matches the data length
    # (100 points for the caller in this file).
    x = np.linspace(start=-1, stop=1.0, num=len(recall), endpoint=True).tolist()
    plt.figure(figsize=(10, 6))
    plt.plot(x, recall, color='red', label='recall:TP/TPFN')
    plt.plot(x, recall_TN, color='black', label='recall_TN:TN/TNFP')
    # Precision labels previously repeated the recall formulas.
    plt.plot(x, PrecisePos, color='blue', label='PrecisePos:TP/TPFP')
    plt.plot(x, PreciseNeg, color='green', label='PreciseNeg:TN/TNFN')
    plt.plot(x, Correct, color='m', label='Correct:(TN+TP)/(TPFN+TNFP)')
    plt.legend()
    plt.xlabel('threshold')

    plt.grid(True, linestyle='--', alpha=0.5)
    plt.savefig('grid.png')
    plt.show()
    plt.close()
|
||
|
||
|
||
def showHist(same, cross):
    """Save stacked histograms of the two score groups to ``plot.png``.

    Args:
        same: Scores drawn in the upper panel, titled 'Same Barcode'.
        cross: Scores drawn in the lower panel, titled 'Cross Barcode'.

    Both panels use 100 bins and an x range of [-1, 1]. The image is written
    to the current working directory.
    """
    Same = np.array(same)
    Cross = np.array(cross)

    fig, axs = plt.subplots(2, 1)
    for axis, values, title in (
            (axs[0], Same, 'Same Barcode'),
            (axs[1], Cross, 'Cross Barcode'),
    ):
        axis.hist(values, bins=100, edgecolor='black')
        axis.set_xlim([-1, 1])
        axis.set_title(title)

    fig.savefig('plot.png')
    # Release the figure; pyplot otherwise keeps it alive for the lifetime
    # of the process.
    plt.close(fig)
|
||
|
||
|
||
def compute_accuracy_recall(score, labels):
    """Sweep 100 thresholds over [-1, 1] and report precision/recall curves.

    For every threshold a sample is predicted positive when its score is
    ``>=`` the threshold (the same rule ``threshold_search`` applies) and
    negative otherwise, so each sample lands in exactly one of TP/FP/TN/FN.
    Each metric is reported as 0 when its denominator is 0.

    Side effects: prints one line per threshold, then calls ``showHist`` with
    the scores of the positive and negative samples and ``showgrid`` with the
    five metric curves.

    Args:
        score: Similarity scores, one per sample.
        labels: Ground-truth labels aligned with ``score``; 1 marks a
            positive sample, 0 a negative one. Values convertible with
            ``int`` (such as the strings '0'/'1') are accepted.

    Returns:
        None.
    """
    score = np.asarray(score, dtype=float)
    # Labels may arrive as strings (e.g. straight from JSON); compare as ints.
    labels = np.asarray(labels).astype(int)
    thresholds = np.linspace(-1, 1, num=100)
    recall, PrecisePos, PreciseNeg, recall_TN, Correct = [], [], [], [], []

    is_pos = labels == 1
    is_neg = labels == 0
    # Split by label rather than by position, so the histograms stay correct
    # when the samples are not ordered positives-first.
    Same = score[is_pos]
    Cross = score[is_neg]

    def _ratio(numerator, denominator):
        # Explicit zero check: comparing a NaN result against the string
        # 'nan' never matches, so an empty denominator used to leak NaN.
        return numerator / denominator if denominator else 0

    for th in thresholds:
        pred_pos = score >= th
        pred_neg = np.logical_not(pred_pos)
        TP = int(np.sum(np.logical_and(pred_pos, is_pos)))
        FN = int(np.sum(np.logical_and(pred_neg, is_pos)))
        TN = int(np.sum(np.logical_and(pred_neg, is_neg)))
        FP = int(np.sum(np.logical_and(pred_pos, is_neg)))

        PrecisePos.append(_ratio(TP, TP + FP))
        PreciseNeg.append(_ratio(TN, TN + FN))
        recall.append(_ratio(TP, TP + FN))
        recall_TN.append(_ratio(TN, TN + FP))
        # Accuracy stays meaningful when TP is 0 but TN is not.
        Correct.append(_ratio(TP + TN, TP + FP + TN + FN))
        print("Threshold:{} PrecisePos:{},recall:{},PreciseNeg:{},recall_TN:{}".format(
            th, PrecisePos[-1], recall[-1], PreciseNeg[-1], recall_TN[-1]))

    showHist(Same, Cross)
    showgrid(recall, recall_TN, PrecisePos, PreciseNeg, Correct)
|
||
|
||
|
||
def compute_accuracy(
        feature_dict: Dict[str, torch.Tensor],
        cam: cal_cam,
) -> Tuple[float, float]:
    """Score every pair of the configured test list and find the best threshold.

    The pair list (``conf['data']['test_list']``) is read line by line; each
    non-blank line must split into ``img1 img2 label``. Image names are
    joined with ``conf['data']['test_dir']`` to form the keys looked up in
    ``feature_dict``. Progress and each similarity are printed, and when
    ``conf['data']['save_image_joint']`` is truthy ``merge_imgs`` is called
    for the pair.

    Args:
        feature_dict: Mapping from image path to its feature tensor.
        cam: Object forwarded unchanged to ``merge_imgs``.

    Returns:
        ``(accuracy, threshold)`` as returned by ``threshold_search``.

    Raises:
        IOError: If the pair list cannot be read (printed, then re-raised).
        ValueError: If a pair references an image without features, or a
            line does not split into exactly three fields.
    """
    try:
        pair_list = conf['data']['test_list']
        test_root = conf['data']['test_dir']
        with open(pair_list, 'r') as f:
            pairs = f.readlines()
    except IOError as e:
        print(f"Error reading pair list: {e}")
        raise

    similarities, labels = [], []

    for pair in (raw.strip() for raw in pairs):
        if not pair:
            continue

        print(f"Processing pair: {pair}")
        img1, img2, label = pair.split()
        img1_path = osp.join(test_root, img1)
        img2_path = osp.join(test_root, img2)

        # Both images must have been featurized beforehand.
        if not (img1_path in feature_dict and img2_path in feature_dict):
            raise ValueError(f"Missing features for image pair: {img1_path}, {img2_path}")

        vec_a, vec_b = (
            feature_dict[path].cpu().numpy() for path in (img1_path, img2_path)
        )
        similarity = cosin_metric(vec_a, vec_b)
        print('{} vs {}: {}'.format(img1_path, img2_path, similarity))
        if conf['data']['save_image_joint']:
            merge_imgs(img1_path, img2_path, conf, similarity, label, cam)
        similarities.append(similarity)
        labels.append(int(label))

    # Best single threshold first, then the full precision/recall sweep.
    accuracy, threshold = threshold_search(similarities, labels)
    compute_accuracy_recall(np.array(similarities), np.array(labels))

    return accuracy, threshold
|
||
|
||
|
||
def deal_group_pair(pairList1, pairList2):
    """Score each feature of ``pairList1`` against all of ``pairList2``.

    For every feature in ``pairList1`` the cosine similarity to each feature
    in ``pairList2`` is computed with ``cosin_metric`` and the maximum is
    kept.

    Args:
        pairList1: Iterable of feature tensors (objects that support
            ``.cpu().numpy()``).
        pairList2: Iterable of feature tensors of the same kind.

    Returns:
        List with one best-match similarity per element of ``pairList1``, in
        order. Empty when ``pairList1`` is empty.

    Raises:
        ValueError: If ``pairList1`` has elements but ``pairList2`` is empty.
    """
    # Convert once up front instead of once per inner-loop iteration.
    feats1 = [feature.cpu().numpy() for feature in pairList1]
    feats2 = [feature.cpu().numpy() for feature in pairList2]
    if feats1 and not feats2:
        raise ValueError("pairList2 is empty; no similarity can be computed")

    allsimilarity = []
    for feat1 in feats1:
        # A fresh list per query feature: sharing one list across iterations
        # turned every result into a running maximum over earlier features.
        one_similarity = [cosin_metric(feat1, feat2) for feat2 in feats2]
        # Aggregation is the maximum; mean and median were the alternatives
        # noted by the original author.
        allsimilarity.append(max(one_similarity))
    return allsimilarity
|
||
|
||
|
||
def compute_group_accuracy(content_list_read, net=None):
    """Compute similarities and labels for group-versus-group test entries.

    Each entry is indexed as ``entry[0]`` and ``entry[1]`` (two lists of
    image names) and ``entry[-1]`` (the label, converted with ``int``). Both
    groups are featurized in batches, every image of the first group is
    scored against the second group with ``deal_group_pair``, and the
    absolute value of each score is recorded together with the entry's label.

    An entry that fails for any reason is reported on stdout and skipped, so
    one bad entry does not abort the whole evaluation.

    Args:
        content_list_read: Iterable of entries as described above.
        net: Network used for feature extraction. Defaults to the
            module-level ``model`` created by the script entry point.

    Returns:
        ``(similarities, labels)``: two lists of equal length; labels are
        ints, one per similarity.
    """
    if net is None:
        # Falls back to the global created under ``if __name__ == '__main__'``.
        net = model

    # ``conf`` is a plain dict loaded from YAML, so attribute access such as
    # ``conf.test_val`` always raised. The keys below are the ones the
    # pairwise path of this script reads.
    # NOTE(review): ``test_dir`` is assumed to be the root for group images
    # too - confirm against the config schema.
    test_root = conf['data']['test_dir']
    batch_size = conf['data']['test_batch_size']
    run_device = conf['base']['device']
    _, test_transform = get_transform(conf)

    allSimilarity, allLabel = [], []
    for data_loaded in content_list_read:
        print(data_loaded)
        try:
            one_group_list = []
            for i in range(2):
                images = [osp.join(test_root, img) for img in data_loaded[i]]
                features = []
                # Featurize every batch, not only the first one.
                for batch in group_image(images, batch_size):
                    features.extend(
                        featurize(batch, test_transform, net, run_device).values()
                    )
                one_group_list.append(features)
            # deal_group_pair returns a list, so abs() is applied per element.
            similarity = [
                abs(value)
                for value in deal_group_pair(one_group_list[0], one_group_list[1])
            ]
            label = int(data_loaded[-1])
        except Exception as e:
            # Best effort is kept, but the failure is no longer silent.
            print(f"Skipping group entry {data_loaded}: {e}")
            continue

        allSimilarity.extend(similarity)
        # One label per similarity keeps both lists aligned for the metrics.
        allLabel.extend([label] * len(similarity))

    return allSimilarity, allLabel
|
||
|
||
|
||
def init_model():
    """Build the configured backbone and load its checkpoint.

    Reads ``conf['models']['backbone']``, ``conf['models']['model_path']``,
    ``conf['models']['half']``, ``conf['base']['device']`` and
    ``conf['base']['distributed']``.

    * With more than one CUDA device and ``distributed`` enabled, the model
      is wrapped in ``nn.DataParallel`` and the checkpoint is loaded as is.
    * Otherwise the checkpoint is loaded directly; if that fails because the
      keys do not match, ``"module."`` is removed from every key and the
      load is retried with ``strict=False``.

    Returns:
        The model on the configured device, converted to half precision when
        ``conf['models']['half']`` is truthy.

    Raises:
        ValueError: If the configured backbone is not offered by
            ``trainer_tools(conf).get_backbone()``.
    """
    backbone_name = conf['models']['backbone']
    run_device = conf['base']['device']
    model_path = conf['models']['model_path']

    backbone_mapping = trainer_tools(conf).get_backbone()
    if backbone_name not in backbone_mapping:
        # Format the name itself; wrapping it in braces built a set and
        # printed "{'name'}".
        raise ValueError('不支持该模型: {}'.format(backbone_name))
    model = backbone_mapping[backbone_name]().to(run_device)
    print('load model {} '.format(backbone_name))

    # Read the checkpoint once; the fallback path reuses it instead of
    # loading the file from disk a second time.
    state_dict = torch.load(model_path, map_location=run_device)

    if torch.cuda.device_count() > 1 and conf['base']['distributed']:
        model = nn.DataParallel(model).to(run_device)
        # Regular checkpoint loading.
        # NOTE: the original kept a commented-out variant for
        # contrastive-learning checkpoints that renamed "module.base_model."
        # keys to "module." and loaded with strict=False.
        model.load_state_dict(state_dict)
    else:
        try:
            model.load_state_dict(state_dict)
        except RuntimeError:
            # Key mismatch: the checkpoint was presumably saved from a
            # DataParallel wrapper, so drop the "module." prefix and retry.
            stripped = {
                key.replace("module.", ""): value
                for key, value in state_dict.items()
            }
            model.load_state_dict(stripped, strict=False)

    if conf['models']['half']:
        model.half()
        # Report the dtype of the first parameter as a sanity check.
        first_param_dtype = next(model.parameters()).dtype
        print("模型的第一个参数的数据类型: {}".format(first_param_dtype))
    return model
|
||
|
||
|
||
if __name__ == '__main__':
    # Script entry point: build the model, then run either the group test or
    # the pairwise test depending on conf['data']['group_test'].
    model = init_model()
    model.eval()
    cam = cal_cam(model, conf)

    if conf['data']['group_test']:
        # Group mode: entries come from a JSON file.
        filename = conf['data']['test_group_json']
        with open(filename, 'r', encoding='utf-8') as file:
            content_list_read = json.load(file)
        Similarity, Label = compute_group_accuracy(content_list_read)
        compute_accuracy_recall(np.array(Similarity), np.array(Label))
    else:
        # Pairwise mode: featurize every distinct image once, then score pairs.
        images = unique_image(conf['data']['test_list'])
        images = [osp.join(conf['data']['test_dir'], img) for img in images]
        # Take the images in batches of test_batch_size.
        groups = group_image(images, conf['data']['test_batch_size'])
        _, test_transform = get_transform(conf)

        feature_dict = {}
        for group in groups:
            feature_dict.update(
                featurize(group, test_transform, model, conf['base']['device'])
            )

        accuracy, threshold = compute_accuracy(feature_dict, cam)
        print(
            "Test Model: {} Accuracy: {} Threshold: {}".format(conf['models']['model_path'], accuracy, threshold)
        )
|