# detecttracking/contrast/test_ori.py
# -*- coding: utf-8 -*-
import os
import os.path as osp
import pdb
import numpy as np
import shutil
from scipy.spatial.distance import cdist
import torch
import torch.nn as nn
from PIL import Image
import json
from config import config as conf
from model import resnet18
# from model import (mobilevit_s, resnet14, resnet18, resnet34, resnet50, mobilenet_v2,
#                    MobileNetV3_Small, mobilenet_v1, PPLCNET_x1_0, PPLCNET_x0_5, PPLCNET_x2_5)
import matplotlib.pyplot as plt
import statistics

embedding_size = conf.embedding_size
img_size = conf.img_size
device = conf.device

def unique_image(pair_list) -> set:
    """Return the unique image paths listed in pair_list.txt."""
    with open(pair_list, 'r') as fd:
        pairs = fd.readlines()
    unique = set()
    for pair in pairs:
        id1, id2, _ = pair.split()
        unique.add(id1)
        unique.add(id2)
    return unique

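# Example, assuming each pair-list line is "<img1> <img2> <label>":
#   the line "001.png 002.png 1" contributes {'001.png', '002.png'} to the set.
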
def group_image(images: set, batch) -> list:
    """Group image paths by batch size"""
    images = list(images)
    size = len(images)
    res = []
    for i in range(0, size, batch):
        end = min(batch + i, size)
        res.append(images[i:end])
    return res

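# Example (hypothetical paths; ordering within a batch follows set iteration order):
#   group_image({'a.png', 'b.png', 'c.png'}, batch=2) -> [['a.png', 'b.png'], ['c.png']]
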
def _preprocess(images: list, transform) -> torch.Tensor:
    """Load each image, apply `transform`, and stack into one batch tensor."""
    res = []
    for img in images:
        im = Image.open(img)
        im = transform(im)
        res.append(im)
    # data = torch.cat(res, dim=0)  # shape: (batch, 128, 128)
    # data = data[:, None, :, :]   # shape: (batch, 1, 128, 128)
    data = torch.stack(res)
    return data


def test_preprocess(images: list, transform) -> torch.Tensor:
    """Test-time twin of _preprocess; the two are currently identical."""
    res = []
    for img in images:
        im = Image.open(img)
        im = transform(im)
        res.append(im)
    # data = torch.cat(res, dim=0)  # shape: (batch, 128, 128)
    # data = data[:, None, :, :]   # shape: (batch, 1, 128, 128)
    data = torch.stack(res)
    return data

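# Shape sketch, assuming conf.test_transform yields a (3, img_size, img_size) tensor:
#   test_preprocess(['a.png', 'b.png'], conf.test_transform).shape
#   -> torch.Size([2, 3, img_size, img_size])
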
def featurize(images: list, transform, net, device, train=False) -> dict:
    """featurize each image and save into a dictionary

    Args:
        images: image paths
        transform: test transform
        net: pretrained model
        device: cpu or cuda
    Returns:
        Dict (key: imagePath, value: feature)
    """
    # The train/test branches differ only in the preprocessing helper, and
    # _preprocess and test_preprocess are currently identical.
    if train:
        data = _preprocess(images, transform)
    else:
        data = test_preprocess(images, transform)
    data = data.to(device)
    net = net.to(device)
    with torch.no_grad():
        features = net(data)
    res = {img: feature for (img, feature) in zip(images, features)}
    return res

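# Usage sketch (mirrors the non-group path in __main__ below):
#   groups = group_image(images, conf.test_batch_size)
#   feature_dict = featurize(groups[0], conf.test_transform, model, conf.device)
#   # feature_dict maps each image path to a 1-D feature tensor on `device`.
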
def inference_image(images: list, transform, net, device, bs=16, embedding_size=256) -> np.ndarray:
    """Run batched inference and return an (N, embedding_size) feature matrix."""
    batch_patches = []
    patches = []
    for d, img in enumerate(images):
        img = Image.open(img)
        patch = transform(img)

        if str(device) != "cpu":
            # NOTE: assumes the model was also converted with net.half()
            patch = patch.to(device).half()
        else:
            patch = patch.to(device)

        patches.append(patch)
        if (d + 1) % bs == 0:
            patches = torch.stack(patches, dim=0)
            batch_patches.append(patches)
            patches = []
    if len(patches):
        patches = torch.stack(patches, dim=0)
        batch_patches.append(patches)

    features = np.zeros((0, embedding_size), dtype=np.float32)
    for patches in batch_patches:
        with torch.no_grad():  # inference only; gradients are not needed
            pred = net(patches)
        pred[torch.isinf(pred)] = 1.0
        feat = pred.cpu().data.numpy()
        features = np.vstack((features, feat))
    return features

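# Usage sketch (assumes a CUDA device; the .half() cast above then requires a
# half-precision model):
#   net = resnet18().to(conf.device).half().eval()
#   feats = inference_image(paths, conf.test_transform, net, conf.device)
#   # feats.shape == (len(paths), 256)
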
def featurize_1(images: list, transform, net, device, train=False) -> np.ndarray:
    """featurize each image and return the features as one NumPy array

    Args:
        images: image paths
        transform: test transform
        net: pretrained model
        device: cpu or cuda
    Returns:
        np.ndarray of shape (len(images), embedding_size)
    """
    data = test_preprocess(images, transform)
    data = data.to(device)
    net = net.to(device)
    with torch.no_grad():
        # .cpu() before .numpy() so this also works when device is CUDA
        features = net(data).cpu().numpy()
    return features

def cosin_metric(x1, x2):
    return np.dot(x1, x2) / (np.linalg.norm(x1) * np.linalg.norm(x2))

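# Worked example:
#   cosin_metric(np.array([1., 0.]), np.array([1., 1.])) = 1 / sqrt(2) ≈ 0.7071
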
def threshold_search(y_score, y_true):
    y_score = np.asarray(y_score)
    y_true = np.asarray(y_true)
    best_acc = 0
    best_th = 0
    for i in range(len(y_score)):
        th = y_score[i]
        y_test = (y_score >= th)
        acc = np.mean((y_test == y_true).astype(int))
        if acc > best_acc:
            best_acc = acc
            best_th = th
    return best_acc, best_th

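# Worked example (only observed scores are tried as thresholds):
#   threshold_search([0.9, 0.8, 0.3, 0.2], [1, 1, 0, 0]) -> (1.0, 0.8)
#   th=0.8 labels the first two pairs positive and the rest negative: 4/4 correct.
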
def showgrid(recall, recall_TN, PrecisePos, PreciseNeg):
    x = np.linspace(start=-1.0, stop=1.0, num=50, endpoint=True).tolist()
    plt.figure(figsize=(10, 6))
    plt.plot(x, recall, color='red', label='recall')
    plt.plot(x, recall_TN, color='black', label='recall_TN')
    plt.plot(x, PrecisePos, color='blue', label='PrecisePos')
    plt.plot(x, PreciseNeg, color='green', label='PreciseNeg')
    plt.legend()
    plt.xlabel('threshold')
    # plt.ylabel('Similarity')
    plt.grid(True, linestyle='--', alpha=0.5)
    plt.savefig('accuracy_recall_grid.png')
    plt.show()
    plt.close()

def compute_accuracy_recall(score, labels):
    thresholds = np.linspace(-1, 1, num=50)
    # thresholds = [0.4]
    recall, PrecisePos, PreciseNeg, recall_TN = [], [], [], []
    for th in thresholds:
        t_score = (score > th)
        t_labels = (labels == 1)
        # print(t_score)
        # print(t_labels)
        TP = np.sum(np.logical_and(t_score, t_labels))
        FN = np.sum(np.logical_and(np.logical_not(t_score), t_labels))

        f_score = (score < th)
        f_labels = (labels == 0)
        TN = np.sum(np.logical_and(f_score, f_labels))
        FP = np.sum(np.logical_and(np.logical_not(f_score), f_labels))
        print("Threshold:{} TP:{},FP:{},TN:{},FN:{}".format(th, TP, FP, TN, FN))

        # Guard the zero-denominator cases. (The original checked
        # `TP / (TP + FP) == 'nan'`, which is never true and could divide by zero.)
        PrecisePos.append(0 if (TP + FP) == 0 else TP / (TP + FP))
        PreciseNeg.append(0 if (TN + FN) == 0 else TN / (TN + FN))
        recall.append(0 if (TP + FN) == 0 else TP / (TP + FN))
        recall_TN.append(0 if (TN + FP) == 0 else TN / (TN + FP))
    showgrid(recall, recall_TN, PrecisePos, PreciseNeg)

def compute_accuracy(feature_dict, pair_list, test_root):
    with open(pair_list, 'r') as f:
        pairs = f.readlines()

    similarities = []
    labels = []
    for pair in pairs:
        img1, img2, label = pair.split()
        img1 = osp.join(test_root, img1)
        img2 = osp.join(test_root, img2)
        feature1 = feature_dict[img1].cpu().numpy()
        feature2 = feature_dict[img2].cpu().numpy()
        label = int(label)

        similarity = cosin_metric(feature1, feature2)
        similarities.append(similarity)
        labels.append(label)

    accuracy, threshold = threshold_search(similarities, labels)
    # print('similarities >> {}'.format(similarities))
    # print('labels >> {}'.format(labels))
    compute_accuracy_recall(np.array(similarities), np.array(labels))
    return accuracy, threshold

def deal_group_pair(pairList1, pairList2):
    allsimilarity = []
    for pair1 in pairList1:
        # Reset per feature so the aggregate below is per-row, not cumulative.
        one_similarity = []
        for pair2 in pairList2:
            similarity = cosin_metric(pair1.cpu().numpy(), pair2.cpu().numpy())
            one_similarity.append(similarity)
        allsimilarity.append(max(one_similarity))                          # max
        # allsimilarity.append(sum(one_similarity) / len(one_similarity))  # mean
        # allsimilarity.append(statistics.median(one_similarity))          # median
    # print(allsimilarity)
    # print(labels)
    return allsimilarity

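# Aggregation sketch (hypothetical 2-D features as torch tensors):
#   deal_group_pair([torch.tensor([1., 0.])],
#                   [torch.tensor([1., 0.]), torch.tensor([0., 1.])]) -> [1.0]
#   # the max over that row's similarities {1.0, 0.0}
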
def compute_group_accuracy(content_list_read):
    allSimilarity, allLabel = [], []
    for data_loaded in content_list_read:
        one_group_list = []
        for i in range(2):
            images = [osp.join(conf.test_val, img) for img in data_loaded[i]]
            group = group_image(images, conf.test_batch_size)
            # Only the first batch of each group is featurized.
            d = featurize(group[0], conf.test_transform, model, conf.device)
            one_group_list.append(d.values())
        similarity = deal_group_pair(one_group_list[0], one_group_list[1])
        allLabel.append(data_loaded[-1])
        allSimilarity.extend(similarity)
    # print(allSimilarity)
    # print(allLabel)
    return allSimilarity, allLabel

def compute_contrast_accuracy(content_list_read):
    same_folder_pairs = content_list_read['same_folder_pairs']
    cross_folder_pairs = content_list_read['cross_folder_pairs']
    npairs = min(len(same_folder_pairs), len(cross_folder_pairs))

    same_pairs = same_folder_pairs[:npairs]
    cross_pairs = cross_folder_pairs[:npairs]

    same_pairs_similarity = []
    for i in range(len(same_pairs)):
        images_a = [osp.join(conf.test_val, img) for img in same_pairs[i][0]]
        images_b = [osp.join(conf.test_val, img) for img in same_pairs[i][1]]

        feats_a = inference_image(images_a, conf.test_transform, model, conf.device)
        feats_b = inference_image(images_b, conf.test_transform, model, conf.device)
        # matrix = 1 - np.maximum(0.0, cdist(feats_a, feats_b, 'cosine'))
        matrix = 1 - cdist(feats_a, feats_b, 'cosine')

        feats_am = np.mean(feats_a, axis=0, keepdims=True)
        feats_bm = np.mean(feats_b, axis=0, keepdims=True)
        # Mean-feature similarity; computed but currently unused.
        matrixm = 1 - np.maximum(0.0, cdist(feats_am, feats_bm, 'cosine'))

        same_pairs_similarity.append(np.mean(matrix))

        '''Save the same-barcode image pairs (disabled)'''
        # foldi = os.path.join('./result/same', f'{i}')
        # if os.path.exists(foldi):
        #     shutil.rmtree(foldi)
        #     os.makedirs(foldi)
        # else:
        #     os.makedirs(foldi)
        # for ipt in range(len(images_a)):
        #     source_path = images_a[ipt]
        #     destination_path = os.path.join(foldi, f'a_{ipt}.png')
        #     shutil.copy2(source_path, destination_path)
        # for ipt in range(len(images_b)):
        #     source_path = images_b[ipt]
        #     destination_path = os.path.join(foldi, f'b_{ipt}.png')
        #     shutil.copy2(source_path, destination_path)

    cross_pairs_similarity = []
    for i in range(len(cross_pairs)):
        images_a = [osp.join(conf.test_val, img) for img in cross_pairs[i][0]]
        images_b = [osp.join(conf.test_val, img) for img in cross_pairs[i][1]]

        feats_a = inference_image(images_a, conf.test_transform, model, conf.device)
        feats_b = inference_image(images_b, conf.test_transform, model, conf.device)
        # matrix = 1 - np.maximum(0.0, cdist(feats_a, feats_b, 'cosine'))
        matrix = 1 - cdist(feats_a, feats_b, 'cosine')

        feats_am = np.mean(feats_a, axis=0, keepdims=True)
        feats_bm = np.mean(feats_b, axis=0, keepdims=True)
        # Mean-feature similarity; computed but currently unused.
        matrixm = 1 - np.maximum(0.0, cdist(feats_am, feats_bm, 'cosine'))

        cross_pairs_similarity.append(np.mean(matrix))

        '''Save the cross-barcode image pairs (disabled)'''
        # foldi = os.path.join('./result/cross', f'{i}')
        # if os.path.exists(foldi):
        #     shutil.rmtree(foldi)
        #     os.makedirs(foldi)
        # else:
        #     os.makedirs(foldi)
        # for ipt in range(len(images_a)):
        #     source_path = images_a[ipt]
        #     destination_path = os.path.join(foldi, f'a_{ipt}.png')
        #     shutil.copy2(source_path, destination_path)
        # for ipt in range(len(images_b)):
        #     source_path = images_b[ipt]
        #     destination_path = os.path.join(foldi, f'b_{ipt}.png')
        #     shutil.copy2(source_path, destination_path)

    Thresh = np.linspace(-0.2, 1, 100)
    Same = np.array(same_pairs_similarity)
    Cross = np.array(cross_pairs_similarity)

    fig, axs = plt.subplots(2, 1)
    axs[0].hist(Same, bins=60, edgecolor='black')
    axs[0].set_xlim([-0.2, 1])
    axs[0].set_title('Same Barcode')

    axs[1].hist(Cross, bins=60, edgecolor='black')
    axs[1].set_xlim([-0.2, 1])
    axs[1].set_title('Cross Barcode')

    TPFN = len(Same)
    TNFP = len(Cross)
    Recall_Pos, Recall_Neg = [], []
    Precision_Pos, Precision_Neg = [], []
    Correct = []
    for th in Thresh:
        TP = np.sum(Same > th)
        FN = TPFN - TP
        TN = np.sum(Cross < th)
        FP = TNFP - TN

        Recall_Pos.append(TP / TPFN)
        Recall_Neg.append(TN / TNFP)
        # Guard the precision denominators, which can be zero at extreme thresholds.
        Precision_Pos.append(0 if (TP + FP) == 0 else TP / (TP + FP))
        Precision_Neg.append(0 if (TN + FN) == 0 else TN / (TN + FN))
        Correct.append((TN + TP) / (TPFN + TNFP))

    fig, ax = plt.subplots()
    ax.plot(Thresh, Correct, 'r', label='Correct: (TN+TP)/(TPFN+TNFP)')
    ax.plot(Thresh, Recall_Pos, 'b', label='Recall_Pos: TP/TPFN')
    ax.plot(Thresh, Recall_Neg, 'g', label='Recall_Neg: TN/TNFP')
    ax.plot(Thresh, Precision_Pos, 'c', label='Precision_Pos: TP/(TP+FP)')
    ax.plot(Thresh, Precision_Neg, 'm', label='Precision_Neg: TN/(TN+FN)')
    ax.set_xlim([0, 1])
    ax.set_ylim([0, 1])
    ax.grid(True)
    ax.set_title('PrecisePos & PreciseNeg')
    ax.legend()
    plt.show()

    print("Done!")

if __name__ == '__main__':
    # Network Setup
    if conf.testbackbone == 'resnet18':
        # model = ResIRSE(img_size, embedding_size, conf.drop_ratio).to(device)
        model = resnet18().to(device)
    # elif conf.testbackbone == 'resnet34':
    #     model = resnet34().to(device)
    # elif conf.testbackbone == 'resnet50':
    #     model = resnet50().to(device)
    # elif conf.testbackbone == 'mobilevit_s':
    #     model = mobilevit_s().to(device)
    # elif conf.testbackbone == 'mobilenetv3':
    #     model = MobileNetV3_Small().to(device)
    # elif conf.testbackbone == 'mobilenet_v1':
    #     model = mobilenet_v1().to(device)
    # elif conf.testbackbone == 'PPLCNET_x1_0':
    #     model = PPLCNET_x1_0().to(device)
    # elif conf.testbackbone == 'PPLCNET_x0_5':
    #     model = PPLCNET_x0_5().to(device)
    # elif conf.backbone == 'PPLCNET_x2_5':
    #     model = PPLCNET_x2_5().to(device)
    # elif conf.testbackbone == 'mobilenet_v2':
    #     model = mobilenet_v2().to(device)
    # elif conf.testbackbone == 'resnet14':
    #     model = resnet14().to(device)
    else:
        raise ValueError('Unsupported backbone: {}'.format(conf.testbackbone))
    print('load model {}'.format(conf.testbackbone))

    # model = nn.DataParallel(model).to(conf.device)
    model.load_state_dict(torch.load(conf.test_model, map_location=conf.device))
    model.eval()

    if not conf.group_test:
        images = unique_image(conf.test_list)
        images = [osp.join(conf.test_val, img) for img in images]
        groups = group_image(images, conf.test_batch_size)  # split image paths by batch size

        feature_dict = dict()
        for group in groups:
            d = featurize(group, conf.test_transform, model, conf.device)
            feature_dict.update(d)
        # print('feature_dict', feature_dict)
        accuracy, threshold = compute_accuracy(feature_dict, conf.test_list, conf.test_val)

        print(
            f"Test Model: {conf.test_model}\n"
            f"Accuracy: {accuracy:.3f}\n"
            f"Threshold: {threshold:.3f}\n"
        )
    elif conf.group_test:
        """
        conf.test_val: path of the test dataset
        conf.test_group_json: grouping config file for the test data
        """
        filename = conf.test_group_json
        filename = "../cl/images_1.json"  # NOTE: hardcoded override of conf.test_group_json

        with open(filename, 'r', encoding='utf-8') as file:
            content_list_read = json.load(file)

        compute_contrast_accuracy(content_list_read)

# =============================================================================
#         Similarity, Label = compute_group_accuracy(content_list_read)
#         print('allSimilarity >> {}'.format(Similarity))
#         print('allLabel >> {}'.format(Label))
#         compute_accuracy_recall(np.array(Similarity), np.array(Label))
#         # compute_group_accuracy(data_loaded)
# =============================================================================