477 lines
16 KiB
Python
477 lines
16 KiB
Python
# -*- coding: utf-8 -*-
|
||
import os
|
||
import os.path as osp
|
||
import pdb
|
||
import numpy as np
|
||
import shutil
|
||
from scipy.spatial.distance import cdist
|
||
import torch
|
||
import torch.nn as nn
|
||
from PIL import Image
|
||
import json
|
||
from config import config as conf
|
||
from model import resnet18
|
||
|
||
# from model import (mobilevit_s, resnet14, resnet18, resnet34, resnet50, mobilenet_v2,
|
||
# MobileNetV3_Small, mobilenet_v1, PPLCNET_x1_0, PPLCNET_x0_5, PPLCNET_x2_5)
|
||
|
||
|
||
import matplotlib.pyplot as plt
|
||
import statistics
|
||
# Settings read once from the project config when this module is imported.
embedding_size, img_size, device = conf.embedding_size, conf.img_size, conf.device
|
||
|
||
|
||
def unique_image(pair_list) -> set:
    """Return the set of unique image paths referenced in a pair-list file.

    Each non-blank line of ``pair_list`` must hold three whitespace-separated
    fields: ``<image1> <image2> <label>``. Lines that are empty or contain only
    whitespace are skipped rather than crashing the field unpacking.

    Args:
        pair_list: path to the pair-list text file.

    Returns:
        Set of image paths exactly as written in the file (not joined to any
        root directory).

    Raises:
        ValueError: if a non-blank line does not have exactly three fields.
    """
    unique = set()
    with open(pair_list, 'r') as fd:
        for pair in fd:
            fields = pair.split()
            if not fields:
                # Tolerate blank lines (e.g. stray empty lines at end of file).
                continue
            id1, id2, _ = fields
            unique.add(id1)
            unique.add(id2)
    return unique
|
||
|
||
|
||
def group_image(images: set, batch) -> list:
    """Split a collection of image paths into consecutive batches.

    Args:
        images: iterable of image paths (e.g. the set built by ``unique_image``).
        batch: maximum number of paths per batch.

    Returns:
        List of lists in iteration order of ``images``. Every inner list holds
        ``batch`` paths except possibly the last, which holds the remainder.
    """
    ordered = list(images)
    total = len(ordered)
    return [ordered[start:min(start + batch, total)]
            for start in range(0, total, batch)]
|
||
|
||
|
||
def _preprocess(images: list, transform) -> torch.Tensor:
    """Load images from disk, apply ``transform`` and stack them into a batch.

    Args:
        images: list of image file paths.
        transform: callable mapping a PIL image to a tensor.

    Returns:
        ``torch.stack`` of the per-image transform outputs; the first
        dimension indexes ``images`` in order.
    """
    res = []
    for img in images:
        # The context manager closes the file handle once the transform has
        # consumed the pixels; a bare Image.open leaks one handle per image.
        with Image.open(img) as im:
            res.append(transform(im))
    return torch.stack(res)
|
||
|
||
|
||
def test_preprocess(images: list, transform) -> torch.Tensor:
    """Load evaluation images, apply ``transform`` and stack them into a batch.

    Args:
        images: list of image file paths.
        transform: callable mapping a PIL image to a tensor (the test transform).

    Returns:
        ``torch.stack`` of the per-image transform outputs; the first
        dimension indexes ``images`` in order.
    """
    res = []
    for img in images:
        # Close each file once the transform has read it, to avoid leaking
        # file handles when featurizing many batches.
        with Image.open(img) as im:
            res.append(transform(im))
    return torch.stack(res)
|
||
|
||
|
||
def featurize(images: list, transform, net, device, train=False) -> dict:
    """Run ``net`` on one batch of images and map each path to its feature.

    Args:
        images: image paths forming one batch.
        transform: transform applied to every image before inference.
        net: pretrained model producing one feature per input.
        device: device (cpu or cuda) that both the batch and the model are moved to.
        train: when true the batch is built with ``_preprocess``, otherwise
            with ``test_preprocess``.

    Returns:
        Dict mapping each image path to its feature tensor (left on ``device``).
    """
    preprocess = _preprocess if train else test_preprocess
    batch = preprocess(images, transform).to(device)
    net = net.to(device)
    with torch.no_grad():
        outputs = net(batch)
    return dict(zip(images, outputs))
|
||
|
||
def inference_image(images: list, transform, net, device, bs=16, embedding_size=256) -> np.ndarray:
    """Embed ``images`` with ``net`` in mini-batches of ``bs`` images.

    Args:
        images: sequence of image file paths.
        transform: callable turning a PIL image into a tensor.
        net: model returning one feature row per input image.
        device: target device. For any device other than ``"cpu"`` the inputs
            are cast to the floating dtype of ``net``'s parameters
            (``torch.float16`` if it has none), so a float32 model is not fed
            half-precision inputs.
        bs: number of images per forward pass; must be positive.
        embedding_size: width of the empty array returned when ``images`` is
            empty. Otherwise the width is taken from the model output.

    Returns:
        Array with one row per image, in input order. Its dtype is at least
        ``float32`` (half-precision outputs are promoted). Infinite model
        outputs are replaced by ``1.0``.

    Raises:
        ValueError: if ``bs`` is not positive.
    """
    if bs <= 0:
        raise ValueError(f"bs must be a positive integer, got {bs}")

    images = list(images)

    input_dtype = None
    if str(device) != "cpu":
        params = net.parameters() if hasattr(net, "parameters") else ()
        input_dtype = next(
            (p.dtype for p in params if p.is_floating_point()), torch.float16
        )

    feats = []
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        for start in range(0, len(images), bs):
            batch = []
            for path in images[start:start + bs]:
                # Close each file once the transform has read its pixels.
                with Image.open(path) as img:
                    batch.append(transform(img))
            patches = torch.stack(batch, dim=0).to(device)
            if input_dtype is not None:
                patches = patches.to(input_dtype)
            pred = net(patches)
            pred[torch.isinf(pred)] = 1.0
            feats.append(pred.cpu().numpy())

    if not feats:
        return np.zeros((0, embedding_size), dtype=np.float32)
    # Concatenate once (vstack inside the loop is quadratic) and promote to
    # at least float32, as the former float32 accumulator did.
    out = np.concatenate(feats, axis=0)
    return out.astype(np.promote_types(out.dtype, np.float32), copy=False)
|
||
|
||
|
||
|
||
def featurize_1(images: list, transform, net, device, train=False) -> np.ndarray:
    """Embed one batch of images and return the features as a NumPy array.

    Unlike ``featurize`` this returns an array rather than a path-to-feature
    dict; the first axis follows the order of ``images``.

    Args:
        images: image paths forming one batch.
        transform: test transform applied to every image.
        net: pretrained model.
        device: device (cpu or cuda) the batch and model are moved to.
        train: unused; kept for signature compatibility with ``featurize``.

    Returns:
        NumPy array of the model outputs for the batch.
    """
    data = test_preprocess(images, transform).to(device)
    net = net.to(device)
    with torch.no_grad():
        # .cpu() is required before .numpy(): CUDA tensors cannot be
        # converted to NumPy directly.
        features = net(data).cpu().numpy()
    return features
|
||
|
||
|
||
|
||
|
||
def cosin_metric(x1, x2):
    """Return the cosine similarity ``dot(x1, x2) / (|x1| * |x2|)``.

    A zero-norm input produces a non-finite result instead of raising.
    """
    denominator = np.linalg.norm(x1) * np.linalg.norm(x2)
    numerator = np.dot(x1, x2)
    return numerator / denominator
|
||
|
||
|
||
def threshold_search(y_score, y_true):
|
||
y_score = np.asarray(y_score)
|
||
y_true = np.asarray(y_true)
|
||
best_acc = 0
|
||
best_th = 0
|
||
for i in range(len(y_score)):
|
||
th = y_score[i]
|
||
y_test = (y_score >= th)
|
||
acc = np.mean((y_test == y_true).astype(int))
|
||
if acc > best_acc:
|
||
best_acc = acc
|
||
best_th = th
|
||
return best_acc, best_th
|
||
|
||
|
||
def showgrid(recall, recall_TN, PrecisePos, PreciseNeg, thresholds=None):
    """Plot recall and precision curves against the decision threshold.

    Args:
        recall: positive-class recall, one value per threshold.
        recall_TN: negative-class recall, one value per threshold.
        PrecisePos: positive-class precision, one value per threshold.
        PreciseNeg: negative-class precision, one value per threshold.
        thresholds: x-axis values. When omitted, ``len(recall)`` points evenly
            spaced on [-1, 1] are used; for 50-point curves this equals the
            ``np.linspace(-1, 1, num=50)`` grid swept by
            ``compute_accuracy_recall``.

    The figure is saved to ``accuracy_recall_grid.png`` in the current working
    directory, displayed with ``plt.show()``, then closed.
    """
    if thresholds is None:
        # Derive the grid from the data rather than hard-coding 50 points, so
        # curves of any length line up with their x values.
        thresholds = np.linspace(start=-1.0, stop=1.0, num=len(recall), endpoint=True)
    x = np.asarray(thresholds).tolist()

    plt.figure(figsize=(10, 6))
    plt.plot(x, recall, color='red', label='recall')
    plt.plot(x, recall_TN, color='black', label='recall_TN')
    plt.plot(x, PrecisePos, color='blue', label='PrecisePos')
    plt.plot(x, PreciseNeg, color='green', label='PreciseNeg')
    plt.legend()
    plt.xlabel('threshold')
    plt.grid(True, linestyle='--', alpha=0.5)
    plt.savefig('accuracy_recall_grid.png')
    plt.show()
    plt.close()
|
||
|
||
|
||
def compute_accuracy_recall(score, labels, *, plot=True):
    """Sweep 50 thresholds on [-1, 1] and compute per-class recall/precision.

    For each threshold ``th`` a sample is predicted positive when
    ``score > th`` and negative otherwise. A single prediction mask is used
    for both classes, so a sample scoring exactly ``th`` lands on the same
    side whatever its label. TP/FP/TN/FN are printed for every threshold.

    Args:
        score: similarity scores (array-like).
        labels: ground-truth labels; ``1`` marks a positive pair and ``0`` a
            negative pair.
        plot: when true (the default) the curves are drawn with ``showgrid``.

    Returns:
        Dict with keys ``"thresholds"``, ``"recall"``, ``"recall_TN"``,
        ``"PrecisePos"`` and ``"PreciseNeg"``, each a list with one entry per
        threshold. A ratio whose denominator is zero is reported as ``0``.
    """
    score = np.asarray(score)
    labels = np.asarray(labels)
    squence = np.linspace(-1, 1, num=50)

    is_pos = labels == 1
    is_neg = labels == 0

    recall, PrecisePos, PreciseNeg, recall_TN = [], [], [], []
    for th in squence:
        pred_pos = score > th
        pred_neg = np.logical_not(pred_pos)

        TP = int(np.sum(np.logical_and(pred_pos, is_pos)))
        FN = int(np.sum(np.logical_and(pred_neg, is_pos)))
        TN = int(np.sum(np.logical_and(pred_neg, is_neg)))
        FP = int(np.sum(np.logical_and(pred_pos, is_neg)))
        print("Threshold:{} TP:{},FP:{},TN:{},FN:{}".format(th, TP, FP, TN, FN))

        # Guard on the denominator; comparing the quotient to the string
        # 'nan' can never be true and let NaN into the curves.
        PrecisePos.append(TP / (TP + FP) if TP + FP else 0)
        PreciseNeg.append(TN / (TN + FN) if TN + FN else 0)
        recall.append(TP / (TP + FN) if TP + FN else 0)
        recall_TN.append(TN / (TN + FP) if TN + FP else 0)

    if plot:
        showgrid(recall, recall_TN, PrecisePos, PreciseNeg)

    return {
        "thresholds": squence.tolist(),
        "recall": recall,
        "recall_TN": recall_TN,
        "PrecisePos": PrecisePos,
        "PreciseNeg": PreciseNeg,
    }
|
||
|
||
|
||
def compute_accuracy(feature_dict, pair_list, test_root):
    """Score every pair in ``pair_list`` and return the best-threshold accuracy.

    Each non-blank line of ``pair_list`` holds ``<image1> <image2> <label>``;
    blank lines are skipped. Image names are joined to ``test_root`` and
    looked up in ``feature_dict``. The cosine similarity of each pair is
    then fed to ``threshold_search`` and to ``compute_accuracy_recall``, which
    prints and plots the per-threshold statistics.

    Args:
        feature_dict: mapping from ``osp.join(test_root, name)`` to a feature
            tensor (as built by ``featurize``).
        pair_list: path to the pair-list text file.
        test_root: root directory the image names are relative to.

    Returns:
        ``(accuracy, threshold)`` from ``threshold_search``.

    Raises:
        KeyError: if a referenced image has no entry in ``feature_dict``.
    """
    similarities = []
    labels = []
    with open(pair_list, 'r') as f:
        for pair in f:
            fields = pair.split()
            if not fields:
                continue  # tolerate blank lines in the pair list
            img1, img2, label = fields
            feature1 = feature_dict[osp.join(test_root, img1)].cpu().numpy()
            feature2 = feature_dict[osp.join(test_root, img2)].cpu().numpy()
            similarities.append(cosin_metric(feature1, feature2))
            labels.append(int(label))

    accuracy, threshold = threshold_search(similarities, labels)
    compute_accuracy_recall(np.array(similarities), np.array(labels))
    return accuracy, threshold
|
||
|
||
|
||
def deal_group_pair(pairList1, pairList2):
    """For each feature in ``pairList1``, return its best cosine match in ``pairList2``.

    Args:
        pairList1: iterable of feature tensors (the query group).
        pairList2: iterable of feature tensors (the gallery group).

    Returns:
        List with one entry per element of ``pairList1``: the maximum cosine
        similarity between that feature and any feature of ``pairList2``.

    Raises:
        ValueError: if ``pairList2`` is empty while ``pairList1`` is not.
    """
    # Convert the gallery once instead of once per query feature; this also
    # lets a one-shot iterator be used for pairList2.
    feats2 = [pair2.cpu().numpy() for pair2 in pairList2]

    allsimilarity = []
    for pair1 in pairList1:
        feat1 = pair1.cpu().numpy()
        # Built fresh for every query feature, so each entry reflects only
        # that feature rather than a running maximum over earlier rows.
        one_similarity = [cosin_metric(feat1, feat2) for feat2 in feats2]
        allsimilarity.append(max(one_similarity))  # max
        # Alternatives: mean -> sum(one_similarity) / len(one_similarity);
        # median -> statistics.median(one_similarity)
    return allsimilarity
|
||
|
||
def compute_group_accuracy(content_list_read, net=None):
    """Compute group-to-group similarities with one label per similarity.

    For every entry ``data_loaded`` of ``content_list_read``,
    ``data_loaded[0]`` and ``data_loaded[1]`` are lists of image names
    relative to ``conf.test_val``, and ``data_loaded[-1]`` is the entry's
    label (presumably 0/1 — confirm against the JSON producer). All images of
    both sides are featurized, in batches of ``conf.test_batch_size``. Each
    side-0 feature is then scored against side 1 with ``deal_group_pair``.

    Args:
        content_list_read: list of group entries as described above.
        net: model to featurize with; defaults to the module-level ``model``
            created in the ``__main__`` block.

    Returns:
        ``(allSimilarity, allLabel)`` of equal length: the entry's label is
        repeated once per similarity it produced, so the two lists can be
        passed straight to ``compute_accuracy_recall``.
    """
    if net is None:
        net = model  # module-level global assigned in the __main__ block

    allSimilarity, allLabel = [], []
    for data_loaded in content_list_read:
        one_group_list = []
        for side in range(2):
            images = [osp.join(conf.test_val, img) for img in data_loaded[side]]
            feats = {}
            # Featurize every batch, not just the first one, so groups larger
            # than test_batch_size are not silently truncated.
            for batch in group_image(images, conf.test_batch_size):
                feats.update(featurize(batch, conf.test_transform, net, conf.device))
            one_group_list.append(list(feats.values()))

        similarity = deal_group_pair(one_group_list[0], one_group_list[1])
        allSimilarity.extend(similarity)
        allLabel.extend([data_loaded[-1]] * len(similarity))
    return allSimilarity, allLabel
|
||
|
||
def _pair_similarities(pairs, net):
    """Return, for each pair, the mean cosine similarity over all image combinations."""
    similarities = []
    for pair in pairs:
        images_a = [osp.join(conf.test_val, img) for img in pair[0]]
        images_b = [osp.join(conf.test_val, img) for img in pair[1]]
        feats_a = inference_image(images_a, conf.test_transform, net, conf.device)
        feats_b = inference_image(images_b, conf.test_transform, net, conf.device)
        # cdist's 'cosine' metric is the cosine distance (1 - similarity).
        matrix = 1 - cdist(feats_a, feats_b, 'cosine')
        similarities.append(np.mean(matrix))
    return similarities


def _safe_ratio(num, den):
    """Return ``num / den``, or NaN when ``den`` is 0 (matplotlib skips NaN points)."""
    return num / den if den else float('nan')


def compute_contrast_accuracy(content_list_read, net=None):
    """Compare same-barcode and cross-barcode image groups and plot the results.

    ``content_list_read['same_folder_pairs']`` and
    ``content_list_read['cross_folder_pairs']`` are lists whose items hold two
    lists of image names (relative to ``conf.test_val``) at index 0 and 1.
    Both lists are truncated to the shorter length so the classes are
    balanced. Each pair is scored by the mean cosine similarity over all
    image combinations. The function then draws:

    * histograms of the same / cross similarities, and
    * recall, precision and accuracy curves over 100 thresholds in [-0.2, 1].

    A pair is predicted "same" when its similarity is ``> th``. Ratios with a
    zero denominator are plotted as NaN (i.e. omitted).

    Args:
        content_list_read: dict loaded from the group JSON file.
        net: model used for inference; defaults to the module-level
            ``model`` created in the ``__main__`` block.
    """
    if net is None:
        net = model  # module-level global assigned in the __main__ block

    same_folder_pairs = content_list_read['same_folder_pairs']
    cross_folder_pairs = content_list_read['cross_folder_pairs']
    npairs = min(len(same_folder_pairs), len(cross_folder_pairs))

    Same = np.array(_pair_similarities(same_folder_pairs[:npairs], net))
    Cross = np.array(_pair_similarities(cross_folder_pairs[:npairs], net))

    fig, axs = plt.subplots(2, 1)
    axs[0].hist(Same, bins=60, edgecolor='black')
    axs[0].set_xlim([-0.2, 1])
    axs[0].set_title('Same Barcode')

    axs[1].hist(Cross, bins=60, edgecolor='black')
    axs[1].set_xlim([-0.2, 1])
    axs[1].set_title('Cross Barcode')

    Thresh = np.linspace(-0.2, 1, 100)
    TPFN = len(Same)
    TNFP = len(Cross)
    Recall_Pos, Recall_Neg = [], []
    Precision_Pos, Precision_Neg = [], []
    Correct = []
    for th in Thresh:
        TP = int(np.sum(Same > th))
        FN = TPFN - TP
        # "<=" mirrors the "> th" positive rule, so a cross pair scoring
        # exactly th counts as a true negative rather than a false positive.
        TN = int(np.sum(Cross <= th))
        FP = TNFP - TN

        Recall_Pos.append(_safe_ratio(TP, TPFN))
        Recall_Neg.append(_safe_ratio(TN, TNFP))
        Precision_Pos.append(_safe_ratio(TP, TP + FP))
        Precision_Neg.append(_safe_ratio(TN, TN + FN))
        Correct.append(_safe_ratio(TN + TP, TPFN + TNFP))

    fig, ax = plt.subplots()
    ax.plot(Thresh, Correct, 'r', label='Correct: (TN+TP)/(TPFN+TNFP)')
    ax.plot(Thresh, Recall_Pos, 'b', label='Recall_Pos: TP/TPFN')
    ax.plot(Thresh, Recall_Neg, 'g', label='Recall_Neg: TN/TNFP')
    ax.plot(Thresh, Precision_Pos, 'c', label='Precision_Pos: TP/(TP+FP)')
    ax.plot(Thresh, Precision_Neg, 'm', label='Precision_Neg: TN/(TN+FN)')

    ax.set_xlim([0, 1])
    ax.set_ylim([0, 1])
    ax.grid(True)
    ax.set_title('PrecisePos & PreciseNeg')
    ax.legend()
    plt.show()

    print("Haved done!!!")
|
||
|
||
|
||
|
||
|
||
|
||
if __name__ == '__main__':

    # Network setup. Only resnet18 is currently enabled; the commented
    # branches show how other backbones from ``model`` were wired in.
    if conf.testbackbone == 'resnet18':
        # model = ResIRSE(img_size, embedding_size, conf.drop_ratio).to(device)
        model = resnet18().to(device)
    # elif conf.testbackbone == 'resnet34':
    #     model = resnet34().to(device)
    # elif conf.testbackbone == 'resnet50':
    #     model = resnet50().to(device)
    # elif conf.testbackbone == 'mobilevit_s':
    #     model = mobilevit_s().to(device)
    # elif conf.testbackbone == 'mobilenetv3':
    #     model = MobileNetV3_Small().to(device)
    # elif conf.testbackbone == 'mobilenet_v1':
    #     model = mobilenet_v1().to(device)
    # elif conf.testbackbone == 'PPLCNET_x1_0':
    #     model = PPLCNET_x1_0().to(device)
    # elif conf.testbackbone == 'PPLCNET_x0_5':
    #     model = PPLCNET_x0_5().to(device)
    # elif conf.testbackbone == 'PPLCNET_x2_5':
    #     model = PPLCNET_x2_5().to(device)
    # elif conf.testbackbone == 'mobilenet_v2':
    #     model = mobilenet_v2().to(device)
    # elif conf.testbackbone == 'resnet14':
    #     model = resnet14().to(device)
    else:
        # Report the setting that was actually checked (conf.testbackbone),
        # not conf.backbone, which is a different config field.
        raise ValueError('Have not model {}'.format(conf.testbackbone))

    print('load model {} '.format(conf.testbackbone))
    # model = nn.DataParallel(model).to(conf.device)
    model.load_state_dict(torch.load(conf.test_model, map_location=conf.device))
    model.eval()

    if not conf.group_test:
        images = unique_image(conf.test_list)
        images = [osp.join(conf.test_val, img) for img in images]

        groups = group_image(images, conf.test_batch_size)  # split images into batches of test_batch_size

        feature_dict = dict()
        for group in groups:
            d = featurize(group, conf.test_transform, model, conf.device)
            feature_dict.update(d)

        accuracy, threshold = compute_accuracy(feature_dict, conf.test_list, conf.test_val)

        print(
            f"Test Model: {conf.test_model}\n"
            f"Accuracy: {accuracy:.3f}\n"
            f"Threshold: {threshold:.3f}\n"
        )
    else:
        # conf.test_val: root directory of the test images
        # conf.test_group_json: JSON file describing how test images are grouped
        # Read the configured file; a hard-coded "../cl/images_1.json"
        # assignment used to overwrite this setting.
        filename = conf.test_group_json
        with open(filename, 'r', encoding='utf-8') as file:
            content_list_read = json.load(file)

        compute_contrast_accuracy(content_list_read)

        # =============================================================================
        # Alternative group-based evaluation:
        # Similarity, Label = compute_group_accuracy(content_list_read)
        # print('allSimilarity >> {}'.format(Similarity))
        # print('allLabel >> {}'.format(Label))
        # compute_accuracy_recall(np.array(Similarity), np.array(Label))
        # =============================================================================
|