Files
ieemoo-ai-contrast/test_ori.py
2025-06-13 10:45:53 +08:00

332 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
import os.path as osp
from typing import Dict, List, Set, Tuple
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
import json
import matplotlib.pyplot as plt
# from config import config as conf
from tools.dataset import get_transform
from configs import trainer_tools
import yaml
# Load the evaluation configuration once at import time; the functions below
# read their settings from the module-level ``conf`` mapping.
# The encoding is pinned so parsing does not depend on the platform's locale
# default (which is not UTF-8 on every system).
# NOTE(review): yaml.FullLoader is kept as-is; prefer yaml.safe_load if this
# file could ever come from an untrusted source.
with open('configs/test.yml', 'r', encoding='utf-8') as f:
    conf = yaml.load(f, Loader=yaml.FullLoader)

# Constants from config
embedding_size = conf["base"]["embedding_size"]
img_size = conf["transform"]["img_size"]
device = conf["base"]["device"]
def unique_image(pair_list: str) -> Set[str]:
    """Collect every distinct image name referenced in a pair-list file.

    Each non-blank line is expected to hold exactly three whitespace-separated
    fields: ``<image1> <image2> <label>``. Lines with any other field count are
    reported on stdout and skipped.

    Args:
        pair_list: Path of the pair-list text file.

    Returns:
        The set of image names found in the first two columns.

    Raises:
        IOError: If the file cannot be opened or read (reported, then re-raised).
    """
    seen = set()
    try:
        with open(pair_list, 'r') as handle:
            for raw in handle:
                entry = raw.strip()
                if not entry:
                    continue
                fields = entry.split()
                if len(fields) != 3:
                    print(f"Skipping malformed line: {entry}")
                    continue
                # Only the two image columns matter here; the label is ignored.
                seen.add(fields[0])
                seen.add(fields[1])
    except IOError as e:
        print(f"Error reading pair list file: {e}")
        raise
    return seen
def group_image(images: Set[str], batch_size: int) -> List[List[str]]:
    """Split image paths into consecutive batches.

    Args:
        images: Collection of image paths to split.
        batch_size: Maximum number of paths per batch.

    Returns:
        A list of batches; every batch holds ``batch_size`` paths except
        possibly the last one, which holds the remainder.
    """
    ordered = list(images)
    # Slicing past the end is clamped, so the final batch is simply shorter.
    return [
        ordered[start:start + batch_size]
        for start in range(0, len(ordered), batch_size)
    ]
def _preprocess(images: list, transform) -> torch.Tensor:
    """Open each image, apply ``transform`` and stack the results into a batch.

    Args:
        images: Paths of the images to load.
        transform: Callable applied to every opened PIL image; its outputs
            must be tensors that ``torch.stack`` can combine.

    Returns:
        One tensor with the transformed images stacked along a new first axis.
    """
    tensors = [transform(Image.open(path)) for path in images]
    return torch.stack(tensors)
def test_preprocess(images: list, transform) -> torch.Tensor:
    """Open each image, drop any alpha channel, transform and stack into a batch.

    Same as ``_preprocess`` except that images whose mode is ``'RGBA'`` are
    converted to ``'RGB'`` before the transform is applied.

    Args:
        images: Paths of the images to load.
        transform: Callable applied to every opened PIL image; its outputs
            must be tensors that ``torch.stack`` can combine.

    Returns:
        One tensor with the transformed images stacked along a new first axis.
    """
    def _load_rgb(path):
        picture = Image.open(path)
        return picture.convert('RGB') if picture.mode == 'RGBA' else picture

    return torch.stack([transform(_load_rgb(path)) for path in images])
def featurize(
    images: List[str],
    transform: callable,
    net: nn.Module,
    device: torch.device,
    train: bool = False
) -> Dict[str, torch.Tensor]:
    """Run ``net`` on a batch of images and map each path to its output row.

    Args:
        images: Paths of the images forming one batch.
        transform: Per-image transform handed to the preprocessing function.
        net: Model used to extract the features.
        device: Device the batch and the model are moved to.
        train: When true use ``_preprocess``; otherwise ``test_preprocess``.

    Returns:
        Dict from image path to the matching entry of the model output.

    Raises:
        Exception: Any failure is printed and then re-raised unchanged.
    """
    try:
        loader = _preprocess if train else test_preprocess
        batch = loader(images, transform).to(device)
        net = net.to(device)
        with torch.no_grad():
            # The input is cast to half precision when the config asks for it.
            if conf['models']['half']:
                batch = batch.half()
            outputs = net(batch)
        # Iterating the output walks its first axis, i.e. one entry per image.
        return dict(zip(images, outputs))
    except Exception as e:
        print(f"Error in feature extraction: {e}")
        raise
def cosin_metric(x1, x2):
    """Return the cosine similarity of two vectors.

    Computed as the dot product divided by the product of the two norms; no
    guard is applied for zero-length vectors.
    """
    inner = np.dot(x1, x2)
    scale = np.linalg.norm(x1) * np.linalg.norm(x2)
    return inner / scale
def threshold_search(y_score, y_true):
    """Find the score that, used as a threshold, gives the highest accuracy.

    Every score is tried as a candidate threshold; a sample is predicted
    positive when its score is ``>=`` the candidate. On ties the earliest
    candidate wins.

    Args:
        y_score: Similarity scores.
        y_true: Ground-truth labels aligned with ``y_score``.

    Returns:
        ``(best_accuracy, best_threshold)``; ``(0, 0)`` when no candidate
        reaches an accuracy above zero (including empty input).
    """
    scores = np.asarray(y_score)
    truth = np.asarray(y_true)
    best_acc, best_th = 0, 0
    for candidate in scores:
        predicted = scores >= candidate
        acc = np.mean((predicted == truth).astype(int))
        if acc > best_acc:
            best_acc, best_th = acc, candidate
    return best_acc, best_th
def showgrid(recall, recall_TN, PrecisePos, PreciseNeg, Correct):
x = np.linspace(start=0, stop=1.0, num=50, endpoint=True).tolist()
plt.figure(figsize=(10, 6))
plt.plot(x, recall, color='red', label='recall:TP/TPFN')
plt.plot(x, recall_TN, color='black', label='recall_TN:TN/TNFP')
plt.plot(x, PrecisePos, color='blue', label='PrecisePos:TP/TPFN')
plt.plot(x, PreciseNeg, color='green', label='PreciseNeg:TN/TNFP')
plt.plot(x, Correct, color='m', label='Correct(TN+TP)/(TPFN+TNFP)')
plt.legend()
plt.xlabel('threshold')
# plt.ylabel('Similarity')
plt.grid(True, linestyle='--', alpha=0.5)
plt.savefig('grid.png')
plt.show()
plt.close()
def showHist(same, cross):
Same = np.array(same)
Cross = np.array(cross)
fig, axs = plt.subplots(2, 1)
axs[0].hist(Same, bins=50, edgecolor='black')
axs[0].set_xlim([-0.1, 1])
axs[0].set_title('Same Barcode')
axs[1].hist(Cross, bins=50, edgecolor='black')
axs[1].set_xlim([-0.1, 1])
axs[1].set_title('Cross Barcode')
plt.savefig('plot.png')
def compute_accuracy_recall(score, labels):
th = 0.1
squence = np.linspace(-1, 1, num=50)
recall, PrecisePos, PreciseNeg, recall_TN, Correct = [], [], [], [], []
Same = score[:len(score) // 2]
Cross = score[len(score) // 2:]
for th in squence:
t_score = (score > th)
t_labels = (labels == 1)
TP = np.sum(np.logical_and(t_score, t_labels))
FN = np.sum(np.logical_and(np.logical_not(t_score), t_labels))
f_score = (score < th)
f_labels = (labels == 0)
TN = np.sum(np.logical_and(f_score, f_labels))
FP = np.sum(np.logical_and(np.logical_not(f_score), f_labels))
print("Threshold:{} TP:{},FP:{},TN:{},FN:{}".format(th, TP, FP, TN, FN))
PrecisePos.append(0 if TP / (TP + FP) == 'nan' else TP / (TP + FP))
PreciseNeg.append(0 if TN == 0 else TN / (TN + FN))
recall.append(0 if TP == 0 else TP / (TP + FN))
recall_TN.append(0 if TN == 0 else TN / (TN + FP))
Correct.append(0 if TP == 0 else (TP + TN) / (TP + FP + TN + FN))
showHist(Same, Cross)
showgrid(recall, recall_TN, PrecisePos, PreciseNeg, Correct)
def compute_accuracy(
        feature_dict: Dict[str, torch.Tensor],
        pair_list: str,
        test_root: str
) -> Tuple[float, float]:
    """Score every pair in ``pair_list`` and report the best accuracy and threshold.

    Each non-blank line must read ``<image1> <image2> <label>``. Both image
    names are joined with ``test_root`` to look up their features, whose
    cosine similarity becomes the pair's score. Pairs that cannot be parsed
    or scored are reported on stdout and skipped.

    Args:
        feature_dict: Mapping from full image path to its feature tensor.
        pair_list: Path of the pair-list text file.
        test_root: Directory prefix joined onto the image names.

    Returns:
        ``(accuracy, threshold)`` as returned by ``threshold_search``.

    Raises:
        IOError: If the pair list cannot be read (reported, then re-raised).
    """
    try:
        with open(pair_list, 'r') as f:
            pairs = f.readlines()
    except IOError as e:
        print(f"Error reading pair list: {e}")
        raise

    similarities = []
    labels = []
    for pair in pairs:
        pair = pair.strip()
        if not pair:
            continue
        try:
            img1, img2, label = pair.split()
            # Parse the label up front: a bad label must reject the whole pair
            # before anything is recorded.
            label = int(label)
            img1_path = osp.join(test_root, img1)
            img2_path = osp.join(test_root, img2)
            # Verify features exist
            if img1_path not in feature_dict or img2_path not in feature_dict:
                raise ValueError(f"Missing features for image pair: {img1_path}, {img2_path}")
            # Get features and compute similarity
            feat1 = feature_dict[img1_path].cpu().numpy()
            feat2 = feature_dict[img2_path].cpu().numpy()
            similarity = cosin_metric(feat1, feat2)
        except Exception as e:
            print(f"Skipping invalid pair: {pair}. Error: {e}")
            continue
        # Append both values together so the two lists always stay aligned.
        similarities.append(similarity)
        labels.append(label)

    # Find optimal threshold and accuracy
    accuracy, threshold = threshold_search(similarities, labels)
    compute_accuracy_recall(np.array(similarities), np.array(labels))
    return accuracy, threshold
def deal_group_pair(pairList1, pairList2):
    """Return, for each feature in ``pairList1``, its best match in ``pairList2``.

    The score of a feature is the maximum cosine similarity between it and
    every feature of the second group.

    Args:
        pairList1: Iterable of feature tensors (first group).
        pairList2: Iterable of feature tensors (second group).

    Returns:
        A list with one maximum similarity per feature of ``pairList1``.

    Raises:
        ValueError: If ``pairList1`` is non-empty while ``pairList2`` is empty.
    """
    # Convert once up front instead of once per (pair1, pair2) combination.
    feats1 = [feature.cpu().numpy() for feature in pairList1]
    feats2 = [feature.cpu().numpy() for feature in pairList2]

    allsimilarity = []
    for feat1 in feats1:
        # Start a fresh list for every query feature; sharing one list across
        # queries would let earlier queries' scores leak into later maxima.
        one_similarity = [cosin_metric(feat1, feat2) for feat2 in feats2]
        allsimilarity.append(max(one_similarity))
    return allsimilarity
def compute_group_accuracy(content_list_read):
    """Score group-vs-group entries and return aligned similarities and labels.

    Every entry of ``content_list_read`` is indexed as ``entry[0]`` and
    ``entry[1]`` (two lists of image names) and ``entry[-1]`` (the label).
    Both groups are featurized with the module-level ``model`` and compared
    with ``deal_group_pair``. Entries that fail are reported and skipped.

    Args:
        content_list_read: Iterable of entries as described above.

    Returns:
        ``(similarities, labels)``: two lists of equal length, with the
        entry's integer label repeated once per similarity it produced.
    """
    # ``conf`` is a plain dict loaded from YAML, so settings are read by key.
    # NOTE(review): the image root is taken from conf['data']['test_dir'], the
    # same key the pair-list path of this script uses - confirm that group
    # tests share this directory.
    test_root = conf['data']['test_dir']
    batch_size = conf['data']['test_batch_size']
    target_device = conf['base']['device']
    _, test_transform = get_transform(conf)

    allSimilarity, allLabel = [], []
    for data_loaded in content_list_read:
        print(data_loaded)
        try:
            label = int(data_loaded[-1])
            group_features = []
            for i in range(2):
                images = [osp.join(test_root, img) for img in data_loaded[i]]
                features = []
                # Featurize every batch, not just the first, so groups larger
                # than one batch are not silently truncated.
                for batch in group_image(images, batch_size):
                    # ``model`` is the global created in the __main__ section.
                    features.extend(
                        featurize(batch, test_transform, model, target_device).values()
                    )
                group_features.append(features)
            similarity = deal_group_pair(group_features[0], group_features[1])
        except Exception as e:
            # Best-effort: report the failing entry and carry on with the rest.
            print(f"Skipping invalid group: {data_loaded}. Error: {e}")
            continue
        allSimilarity.extend(similarity)
        # One label per similarity keeps both lists the same length.
        allLabel.extend([label] * len(similarity))
    return allSimilarity, allLabel
def init_model():
    """Build the configured backbone, load its weights and return the model.

    The backbone named by ``conf['models']['backbone']`` is looked up in the
    mapping returned by ``trainer_tools(conf).get_backbone()``. The model is
    wrapped in ``nn.DataParallel`` when more than one CUDA device is visible
    and ``conf['base']['distributed']`` is set. Weights are loaded from
    ``conf['models']['model_path']``, and the model is cast to half precision
    when ``conf['models']['half']`` is set.

    Returns:
        The model on ``conf['base']['device']`` with its weights loaded.

    Raises:
        ValueError: If the configured backbone is not in the mapping.
    """
    tr_tools = trainer_tools(conf)
    backbone_mapping = tr_tools.get_backbone()
    backbone_name = conf['models']['backbone']
    target_device = conf['base']['device']

    if backbone_name not in backbone_mapping:
        # Format the name itself; wrapping it in braces built a set and
        # rendered as "{'name'}" in the message.
        raise ValueError('不支持该模型: {}'.format(backbone_name))
    model = backbone_mapping[backbone_name]().to(target_device)
    print('load model {} '.format(backbone_name))

    # Only the DataParallel wrapping differs between the single- and
    # multi-GPU paths; loading and the optional half cast are shared.
    if torch.cuda.device_count() > 1 and conf['base']['distributed']:
        model = nn.DataParallel(model).to(target_device)

    model.load_state_dict(torch.load(conf['models']['model_path'], map_location=target_device))
    if conf['models']['half']:
        model.half()
        first_param_dtype = next(model.parameters()).dtype
        print("模型的第一个参数的数据类型: {}".format(first_param_dtype))
    return model
if __name__ == '__main__':
    # ``model`` stays a module-level name: compute_group_accuracy reads it.
    model = init_model()
    model.eval()

    data_conf = conf['data']
    if data_conf['group_test']:
        # Group mode: score the groups listed in the JSON file.
        with open(data_conf['test_group_json'], 'r', encoding='utf-8') as file:
            content_list_read = json.load(file)
        Similarity, Label = compute_group_accuracy(content_list_read)
        compute_accuracy_recall(np.array(Similarity), np.array(Label))
    else:
        # Pair mode: featurize each distinct image once, then score the pairs.
        image_paths = [
            osp.join(data_conf['test_dir'], img)
            for img in unique_image(data_conf['test_list'])
        ]
        # Split the images into batches of test_batch_size.
        batches = group_image(image_paths, data_conf['test_batch_size'])
        _, test_transform = get_transform(conf)
        feature_dict = {}
        for batch in batches:
            feature_dict.update(
                featurize(batch, test_transform, model, conf['base']['device'])
            )
        accuracy, threshold = compute_accuracy(
            feature_dict, data_conf['test_list'], data_conf['test_dir']
        )
        print(
            "Test Model: {} Accuracy: {} Threshold: {}".format(conf['models']['model_path'], accuracy, threshold)
        )