# ieemoo-ai-review/detecttracking/contrast/feat_extract/inference.py
# -*- coding: utf-8 -*-
"""
@author: LiChen
"""
import numpy as np
import torch
from pathlib import Path

from utils.config import config as cfg

curpath = Path(__file__).resolve().parents[0]


class FeatsInterface:
    """Thin inference wrapper that turns image crops into embedding vectors."""

    def __init__(self, resnetModel=None):
        self.device = cfg.device
        self.transform = cfg.test_transform
        self.batch_size = cfg.batch_size
        self.embedding_size = cfg.embedding_size

        assert resnetModel is not None, "resnetModel is None"
        self.model = resnetModel
        print(f"Model type: {type(self.model)}")
    def inference(self, images, detections=None):
        '''
        Inputs must be RGB; BGR images have to be converted to RGB first.

        Two call modes:
          - `images` is an ndarray (a full frame): crops are cut out
            according to `detections`, and both the crops and their
            features are returned (see inference_image).
          - `images` is a sequence of pre-cropped patches: only the
            feature matrix is returned.
        '''
        if isinstance(images, np.ndarray):
            imgs, features = self.inference_image(images, detections)
            return imgs, features

        batch_patches = []
        patches = []
        for i, img in enumerate(images):
            img = img.copy()
            patch = self.transform(img)

            # patch = patch.to(device=self.device).half()
            patch = patch.to(device=self.device)

            patches.append(patch)
            # Flush a full batch.
            if (i + 1) % self.batch_size == 0:
                patches = torch.stack(patches, dim=0)
                batch_patches.append(patches)
                patches = []
        # Flush the trailing, possibly smaller batch.
        if len(patches):
            patches = torch.stack(patches, dim=0)
            batch_patches.append(patches)

        features = np.zeros((0, self.embedding_size))
        with torch.no_grad():  # pure inference; no gradients needed
            for patches in batch_patches:
                pred = self.model(patches)
                pred[torch.isinf(pred)] = 1.0  # replace inf activations
                feat = pred.cpu().data.numpy()
                features = np.vstack((features, feat))
        return features
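
    # Worked batching example (illustrative numbers): with batch_size = 8
    # and 20 input crops, the loop above emits two full batches of 8 and
    # the trailing `if len(patches)` flush picks up the remaining 4, so the
    # stacked feature matrix ends up with shape (20, embedding_size).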
    def inference_image(self, image, detections):
        H, W, _ = np.shape(image)

        batch_patches = []
        patches = []
        imgs = []
        for d in range(np.size(detections, 0)):
            tlbr = detections[d, :4].astype(np.int_)
            # Clamp the box to the image; slice stop indices are exclusive,
            # so the right/bottom edges clamp to W and H.
            tlbr[0] = max(0, tlbr[0])
            tlbr[1] = max(0, tlbr[1])
            tlbr[2] = min(W, tlbr[2])
            tlbr[3] = min(H, tlbr[3])
            img = image[tlbr[1]:tlbr[3], tlbr[0]:tlbr[2], :]

            imgs.append(img)
            img1 = img[:, :, ::-1].copy()  # BGR -> RGB; the model expects RGB inputs
            patch = self.transform(img1)

            # patch = patch.to(device=self.device).half()
            patch = patch.to(device=self.device)

            patches.append(patch)
            if (d + 1) % self.batch_size == 0:
                patches = torch.stack(patches, dim=0)
                batch_patches.append(patches)
                patches = []
        if len(patches):
            patches = torch.stack(patches, dim=0)
            batch_patches.append(patches)

        features = np.zeros((0, self.embedding_size))
        with torch.no_grad():
            for patches in batch_patches:
                pred = self.model(patches)
                pred[torch.isinf(pred)] = 1.0
                feat = pred.cpu().data.numpy()
                features = np.vstack((features, feat))
        return imgs, features
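

# Usage sketch (illustrative only: `resnet_model`, `frame`, and `dets` are
# hypothetical names; the backbone and config are assumed to be set up
# elsewhere):
#
#     encoder = FeatsInterface(resnetModel=resnet_model)
#
#     # Full BGR frame of shape (H, W, 3) plus (N, 4+) tlbr detections:
#     crops, feats = encoder.inference(frame, dets)  # feats: (N, embedding_size)
#
#     # Or a list of pre-cropped RGB patches (features only in this mode):
#     feats = encoder.inference(crop_list)
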
# def unique_image(pair_list) -> set:
#     """Return unique image path in pair_list.txt"""
#     with open(pair_list, 'r') as fd:
#         pairs = fd.readlines()
#     unique = set()
#     for pair in pairs:
#         id1, id2, _ = pair.split()
#         unique.add(id1)
#         unique.add(id2)
#     return unique
#
#
# def group_image(images: set, batch) -> list:
#     """Group image paths by batch size"""
#     images = list(images)
#     size = len(images)
#     res = []
#     for i in range(0, size, batch):
#         end = min(batch + i, size)
#         res.append(images[i: end])
#     return res
#
#
# def _preprocess(images: list, transform) -> torch.Tensor:
#     res = []
#     for img in images:
#         im = Image.open(img)
#         im = transform(im)
#         res.append(im)
#     # data = torch.cat(res, dim=0)  # shape: (batch, 128, 128)
#     # data = data[:, None, :, :]   # shape: (batch, 1, 128, 128)
#     data = torch.stack(res)
#     return data
#
#
# def test_preprocess(images: list, transform) -> torch.Tensor:
#     res = []
#     for img in images:
#         im = Image.open(img)
#         im = transform(im)
#         res.append(im)
#     # data = torch.cat(res, dim=0)  # shape: (batch, 128, 128)
#     # data = data[:, None, :, :]   # shape: (batch, 1, 128, 128)
#     data = torch.stack(res)
#     return data
#
#
# def featurize(images: list, transform, net, device, train=False) -> dict:
#     """featurize each image and save into a dictionary
#     Args:
#         images: image paths
#         transform: test transform
#         net: pretrained model
#         device: cpu or cuda
#     Returns:
#         Dict (key: imagePath, value: feature)
#     """
#     if train:
#         data = _preprocess(images, transform)
#         data = data.to(device)
#         net = net.to(device)
#         with torch.no_grad():
#             features = net(data)
#         res = {img: feature for (img, feature) in zip(images, features)}
#     else:
#         data = test_preprocess(images, transform)
#         data = data.to(device)
#         net = net.to(device)
#         with torch.no_grad():
#             features = net(data)
#         res = {img: feature for (img, feature) in zip(images, features)}
#     return res
#
#
# # def inference_image(images: list, transform, net, device, bs=16, embedding_size=256) -> dict:
# #     batch_patches = []
# #     patches = []
# #     for d, img in enumerate(images):
# #         img = Image.open(img)
# #         patch = transform(img)
# #
# #         if str(device) != "cpu":
# #             patch = patch.to(device).half()
# #         else:
# #             patch = patch.to(device)
# #
# #         patches.append(patch)
# #         if (d + 1) % bs == 0:
# #             patches = torch.stack(patches, dim=0)
# #             batch_patches.append(patches)
# #             patches = []
# #
# #     if len(patches):
# #         patches = torch.stack(patches, dim=0)
# #         batch_patches.append(patches)
# #
# #     features = np.zeros((0, embedding_size), dtype=np.float32)
# #     for patches in batch_patches:
# #         pred = net(patches)
# #         pred[torch.isinf(pred)] = 1.0
# #         feat = pred.cpu().data.numpy()
# #         features = np.vstack((features, feat))
# #
# #     return features
#
#
# def featurize_1(images: list, transform, net, device, train=False) -> dict:
#     """featurize each image and save into a dictionary
#     Args:
#         images: image paths
#         transform: test transform
#         net: pretrained model
#         device: cpu or cuda
#     Returns:
#         Dict (key: imagePath, value: feature)
#     """
#     data = test_preprocess(images, transform)
#     data = data.to(device)
#     net = net.to(device)
#     with torch.no_grad():
#         features = net(data).data.numpy()
#
#     return features
#
#
# def cosin_metric(x1, x2):
#     return np.dot(x1, x2) / (np.linalg.norm(x1) * np.linalg.norm(x2))
#
#
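# Quick sanity check for cosin_metric (illustrative values):
#     cosin_metric(np.array([1.0, 0.0]), np.array([1.0, 0.0]))  # -> 1.0
#     cosin_metric(np.array([1.0, 0.0]), np.array([0.0, 1.0]))  # -> 0.0
#
#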
# def threshold_search(y_score, y_true):
#     y_score = np.asarray(y_score)
#     y_true = np.asarray(y_true)
#     best_acc = 0
#     best_th = 0
#     for i in range(len(y_score)):
#         th = y_score[i]
#         y_test = (y_score >= th)
#         acc = np.mean((y_test == y_true).astype(int))
#         if acc > best_acc:
#             best_acc = acc
#             best_th = th
#     return best_acc, best_th
#
#
# def showgrid(recall, recall_TN, PrecisePos, PreciseNeg):
#     x = np.linspace(start=-1.0, stop=1.0, num=50, endpoint=True).tolist()
#     plt.figure(figsize=(10, 6))
#     plt.plot(x, recall, color='red', label='recall')
#     plt.plot(x, recall_TN, color='black', label='recall_TN')
#     plt.plot(x, PrecisePos, color='blue', label='PrecisePos')
#     plt.plot(x, PreciseNeg, color='green', label='PreciseNeg')
#     plt.legend()
#     plt.xlabel('threshold')
#     # plt.ylabel('Similarity')
#     plt.grid(True, linestyle='--', alpha=0.5)
#     plt.savefig('accuracy_recall_grid.png')
#     plt.show()
#     plt.close()
#
#
# def compute_accuracy_recall(score, labels):
#     th = 0.1
#     sequence = np.linspace(-1, 1, num=50)
#     # sequence = [0.4]
#     recall, PrecisePos, PreciseNeg, recall_TN = [], [], [], []
#     for th in sequence:
#         t_score = (score > th)
#         t_labels = (labels == 1)
#         # print(t_score)
#         # print(t_labels)
#         TP = np.sum(np.logical_and(t_score, t_labels))
#         FN = np.sum(np.logical_and(np.logical_not(t_score), t_labels))
#         f_score = (score < th)
#         f_labels = (labels == 0)
#         TN = np.sum(np.logical_and(f_score, f_labels))
#         FP = np.sum(np.logical_and(np.logical_not(f_score), f_labels))
#         print("Threshold:{} TP:{},FP:{},TN:{},FN:{}".format(th, TP, FP, TN, FN))
#
#         PrecisePos.append(0 if (TP + FP) == 0 else TP / (TP + FP))  # avoid 0/0
#         PreciseNeg.append(0 if TN == 0 else TN / (TN + FN))
#         recall.append(0 if TP == 0 else TP / (TP + FN))
#         recall_TN.append(0 if TN == 0 else TN / (TN + FP))
#     showgrid(recall, recall_TN, PrecisePos, PreciseNeg)
#
#
# def compute_accuracy(feature_dict, pair_list, test_root):
#     with open(pair_list, 'r') as f:
#         pairs = f.readlines()
#
#     similarities = []
#     labels = []
#     for pair in pairs:
#         img1, img2, label = pair.split()
#         img1 = osp.join(test_root, img1)
#         img2 = osp.join(test_root, img2)
#         feature1 = feature_dict[img1].cpu().numpy()
#         feature2 = feature_dict[img2].cpu().numpy()
#         label = int(label)
#
#         similarity = cosin_metric(feature1, feature2)
#         similarities.append(similarity)
#         labels.append(label)
#
#     accuracy, threshold = threshold_search(similarities, labels)
#     # print('similarities >> {}'.format(similarities))
#     # print('labels >> {}'.format(labels))
#     compute_accuracy_recall(np.array(similarities), np.array(labels))
#     return accuracy, threshold
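#
#
# The pair list consumed by unique_image/compute_accuracy is a plain text
# file with one whitespace-separated "img1 img2 label" triple per line,
# e.g. (illustrative paths):
#
#     a/0001.png a/0002.png 1
#     a/0001.png b/0003.png 0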
#
#
# def deal_group_pair(pairList1, pairList2):
#     allsimilarity = []
#     one_similarity = []
#     for pair1 in pairList1:
#         for pair2 in pairList2:
#             similarity = cosin_metric(pair1.cpu().numpy(), pair2.cpu().numpy())
#             one_similarity.append(similarity)
#         allsimilarity.append(max(one_similarity))  # max
#         # allsimilarity.append(sum(one_similarity) / len(one_similarity))  # mean
#         # allsimilarity.append(statistics.median(one_similarity))  # median
#     # print(allsimilarity)
#     # print(labels)
#     return allsimilarity
#
#
# def compute_group_accuracy(content_list_read):
#     allSimilarity, allLabel = [], []
#     for data_loaded in content_list_read:
#         one_group_list = []
#         for i in range(2):
#             images = [osp.join(conf.test_val, img) for img in data_loaded[i]]
#             group = group_image(images, conf.test_batch_size)
#             d = featurize(group[0], conf.test_transform, model, conf.device)
#             one_group_list.append(d.values())
#         similarity = deal_group_pair(one_group_list[0], one_group_list[1])
#         allLabel.append(data_loaded[-1])
#         allSimilarity.extend(similarity)
#     # print(allSimilarity)
#     # print(allLabel)
#     return allSimilarity, allLabel
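#
# Note: deal_group_pair scores a pair of groups by the *maximum* pairwise
# cosine similarity (mean and median variants are left commented out); e.g.
# pairwise similarities [0.2, 0.7, 0.5] give a group similarity of 0.7.
#
#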
# def compute_contrast_accuracy(content_list_read):
#     npairs = 50
#
#     same_folder_pairs = content_list_read['same_folder_pairs']
#     cross_folder_pairs = content_list_read['cross_folder_pairs']
#
#     npairs = min((len(same_folder_pairs), len(cross_folder_pairs)))
#
#     Encoder = FeatsInterface(conf)
#
#     same_pairs = same_folder_pairs[:npairs]
#     cross_pairs = cross_folder_pairs[:npairs]
#
#     same_pairs_similarity = []
#     for i in range(len(same_pairs)):
#         images_a = [osp.join(conf.test_val, img) for img in same_pairs[i][0]]
#         images_b = [osp.join(conf.test_val, img) for img in same_pairs[i][1]]
#
#         feats_a = Encoder.inference(images_a)
#         feats_b = Encoder.inference(images_b)
#         # matrix = 1 - np.maximum(0.0, cdist(feats_a, feats_b, 'cosine'))
#         matrix = 1 - cdist(feats_a, feats_b, 'cosine')
#
#         feats_am = np.mean(feats_a, axis=0, keepdims=True)
#         feats_bm = np.mean(feats_b, axis=0, keepdims=True)
#         matrixm = 1 - np.maximum(0.0, cdist(feats_am, feats_bm, 'cosine'))
#
#         same_pairs_similarity.append(np.mean(matrix))
#
#         '''Save image pairs with the same Barcode'''
#         # foldi = os.path.join('./result/same', f'{i}')
#         # if os.path.exists(foldi):
#         #     shutil.rmtree(foldi)
#         #     os.makedirs(foldi)
#         # else:
#         #     os.makedirs(foldi)
#         # for ipt in range(len(images_a)):
#         #     source_path = images_a[ipt]
#         #     destination_path = os.path.join(foldi, f'a_{ipt}.png')
#         #     shutil.copy2(source_path, destination_path)
#         # for ipt in range(len(images_b)):
#         #     source_path = images_b[ipt]
#         #     destination_path = os.path.join(foldi, f'b_{ipt}.png')
#         #     shutil.copy2(source_path, destination_path)
#
#     cross_pairs_similarity = []
#     for i in range(len(cross_pairs)):
#         images_a = [osp.join(conf.test_val, img) for img in cross_pairs[i][0]]
#         images_b = [osp.join(conf.test_val, img) for img in cross_pairs[i][1]]
#
#         feats_a = Encoder.inference(images_a)
#         feats_b = Encoder.inference(images_b)
#         # matrix = 1 - np.maximum(0.0, cdist(feats_a, feats_b, 'cosine'))
#         matrix = 1 - cdist(feats_a, feats_b, 'cosine')
#
#         feats_am = np.mean(feats_a, axis=0, keepdims=True)
#         feats_bm = np.mean(feats_b, axis=0, keepdims=True)
#         matrixm = 1 - np.maximum(0.0, cdist(feats_am, feats_bm, 'cosine'))
#
#         cross_pairs_similarity.append(np.mean(matrix))
#
#         '''Save image pairs with different Barcodes'''
#         # foldi = os.path.join('./result/cross', f'{i}')
#         # if os.path.exists(foldi):
#         #     shutil.rmtree(foldi)
#         #     os.makedirs(foldi)
#         # else:
#         #     os.makedirs(foldi)
#         # for ipt in range(len(images_a)):
#         #     source_path = images_a[ipt]
#         #     destination_path = os.path.join(foldi, f'a_{ipt}.png')
#         #     shutil.copy2(source_path, destination_path)
#         # for ipt in range(len(images_b)):
#         #     source_path = images_b[ipt]
#         #     destination_path = os.path.join(foldi, f'b_{ipt}.png')
#         #     shutil.copy2(source_path, destination_path)
#
#     Thresh = np.linspace(-0.2, 1, 100)
#
#     Same = np.array(same_pairs_similarity)
#     Cross = np.array(cross_pairs_similarity)
#
#     fig, axs = plt.subplots(2, 1)
#     axs[0].hist(Same, bins=60, edgecolor='black')
#     axs[0].set_xlim([-0.2, 1])
#     axs[0].set_title('Same Barcode')
#
#     axs[1].hist(Cross, bins=60, edgecolor='black')
#     axs[1].set_xlim([-0.2, 1])
#     axs[1].set_title('Cross Barcode')
#
#     TPFN = len(Same)
#     TNFP = len(Cross)
#     Recall_Pos, Recall_Neg = [], []
#     Precision_Pos, Precision_Neg = [], []
#     Correct = []
#     for th in Thresh:
#         TP = np.sum(Same > th)
#         FN = TPFN - TP
#         TN = np.sum(Cross < th)
#         FP = TNFP - TN
#
#         Recall_Pos.append(TP / TPFN)
#         Recall_Neg.append(TN / TNFP)
#         Precision_Pos.append(TP / (TP + FP))
#         Precision_Neg.append(TN / (TN + FN))
#         Correct.append((TN + TP) / (TPFN + TNFP))
#
#     fig, ax = plt.subplots()
#     ax.plot(Thresh, Correct, 'r', label='Correct: (TN+TP)/(TPFN+TNFP)')
#     ax.plot(Thresh, Recall_Pos, 'b', label='Recall_Pos: TP/TPFN')
#     ax.plot(Thresh, Recall_Neg, 'g', label='Recall_Neg: TN/TNFP')
#     ax.plot(Thresh, Precision_Pos, 'c', label='Precision_Pos: TP/(TP+FP)')
#     ax.plot(Thresh, Precision_Neg, 'm', label='Precision_Neg: TN/(TN+FN)')
#
#     ax.set_xlim([0, 1])
#     ax.set_ylim([0, 1])
#     ax.grid(True)
#     ax.set_title('PrecisePos & PreciseNeg')
#     ax.legend()
#     plt.show()
#
#     print("Done!!!")
#
#
# if __name__ == '__main__':
#
#     # Network Setup
#     if conf.testbackbone == 'resnet18':
#         # model = ResIRSE(conf.img_size, conf.embedding_size, conf.drop_ratio).to(conf.device)
#         model = resnet18().to(conf.device)
#     # elif conf.testbackbone == 'resnet34':
#     #     model = resnet34().to(conf.device)
#     # elif conf.testbackbone == 'resnet50':
#     #     model = resnet50().to(conf.device)
#     # elif conf.testbackbone == 'mobilevit_s':
#     #     model = mobilevit_s().to(conf.device)
#     # elif conf.testbackbone == 'mobilenetv3':
#     #     model = MobileNetV3_Small().to(conf.device)
#     # elif conf.testbackbone == 'mobilenet_v1':
#     #     model = mobilenet_v1().to(conf.device)
#     # elif conf.testbackbone == 'PPLCNET_x1_0':
#     #     model = PPLCNET_x1_0().to(conf.device)
#     # elif conf.testbackbone == 'PPLCNET_x0_5':
#     #     model = PPLCNET_x0_5().to(conf.device)
#     # elif conf.backbone == 'PPLCNET_x2_5':
#     #     model = PPLCNET_x2_5().to(conf.device)
#     # elif conf.testbackbone == 'mobilenet_v2':
#     #     model = mobilenet_v2().to(conf.device)
#     # elif conf.testbackbone == 'resnet14':
#     #     model = resnet14().to(conf.device)
#     else:
#         raise ValueError('Unsupported backbone: {}'.format(conf.testbackbone))
#
#     print('load model {} '.format(conf.testbackbone))
#     # model = nn.DataParallel(model).to(conf.device)
#     model.load_state_dict(torch.load(conf.test_model, map_location=conf.device))
#     model.eval()
#
#     if not conf.group_test:
#         images = unique_image(conf.test_list)
#         images = [osp.join(conf.test_val, img) for img in images]
#
#         groups = group_image(images, conf.test_batch_size)  # group images by batch_size
#
#         feature_dict = dict()
#         for group in groups:
#             d = featurize(group, conf.test_transform, model, conf.device)
#             feature_dict.update(d)
#         # print('feature_dict', feature_dict)
#         accuracy, threshold = compute_accuracy(feature_dict, conf.test_list, conf.test_val)
#
#         print(
#             f"Test Model: {conf.test_model}\n"
#             f"Accuracy: {accuracy:.3f}\n"
#             f"Threshold: {threshold:.3f}\n"
#         )
#     elif conf.group_test:
#         """
#         conf.test_val: test dataset directory
#         conf.test_group_json: grouping config file for the test data
#         """
#         filename = conf.test_group_json
#
#         filename = "../cl/images_1.json"
#         with open(filename, 'r', encoding='utf-8') as file:
#             content_list_read = json.load(file)
#
#         compute_contrast_accuracy(content_list_read)
#
# =============================================================================
#         Similarity, Label = compute_group_accuracy(content_list_read)
#         print('allSimilarity >> {}'.format(Similarity))
#         print('allLabel >> {}'.format(Label))
#         compute_accuracy_recall(np.array(Similarity), np.array(Label))
#         # compute_group_accuracy(data_loaded)
#
# =============================================================================