# -*- coding: utf-8 -*-
"""Evaluate an embedding model on image pairs or image groups.

Features are extracted in batches, compared with cosine similarity, and the
scores are used to search for the best decision threshold and to plot
precision/recall curves and score histograms.
"""
import json
import os.path as osp
from datetime import datetime
from typing import Callable, Dict, List, Set, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import yaml
from PIL import Image

from configs import trainer_tools
from tools.dataset import get_transform
from tools.getHeatMap import cal_cam
from tools.image_joint import merge_imgs

# NOTE(security): yaml.FullLoader can construct more than plain data types.
# This is acceptable for a trusted local config; switch to yaml.safe_load if
# the file can ever come from an untrusted source.
with open('../configs/test.yml', 'r') as f:
    conf = yaml.load(f, Loader=yaml.FullLoader)

# Constants from config
embedding_size = conf["base"]["embedding_size"]
img_size = conf["transform"]["img_size"]
device = conf["base"]["device"]

# Number of thresholds swept over [-1, 1]; shared by the metric sweep and the
# plot so that their x/y lengths always agree.
_NUM_THRESHOLDS = 100


def unique_image(pair_list: str) -> Set[str]:
    """Collect the distinct image names referenced by a pair-list file.

    Each non-empty line is expected to hold three whitespace-separated
    fields: ``img1 img2 label``. Lines with a different field count are
    reported and skipped.

    Args:
        pair_list: Path of the pair-list text file.

    Returns:
        Set of every image name appearing in the first two columns.

    Raises:
        IOError: If the file cannot be read (re-raised after logging).
    """
    unique_images = set()
    try:
        with open(pair_list, 'r') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    img1, img2, _ = line.split()
                except ValueError:
                    print(f"Skipping malformed line: {line}")
                    continue
                unique_images.update([img1, img2])
    except IOError as e:
        print(f"Error reading pair list file: {e}")
        raise
    return unique_images


def group_image(images: Set[str], batch_size: int) -> List[List[str]]:
    """Group image paths into batches of specified size.

    Args:
        images: Collection of image paths to group.
        batch_size: Number of images per batch.

    Returns:
        List of batches, where each batch is a list of image paths. The last
        batch may be shorter than ``batch_size``.
    """
    image_list = list(images)
    # Slicing clamps at the end of the list, so no explicit min() is needed.
    return [
        image_list[start:start + batch_size]
        for start in range(0, len(image_list), batch_size)
    ]


def _preprocess(images: list, transform) -> torch.Tensor:
    """Load each image, apply ``transform`` and stack the results.

    Args:
        images: Image file paths.
        transform: Callable applied to each opened PIL image; must return a
            tensor so the results can be stacked.

    Returns:
        Tensor produced by ``torch.stack`` over the transformed images.
    """
    res = []
    for img in images:
        # Context manager guarantees the underlying file handle is released.
        with Image.open(img) as im:
            res.append(transform(im))
    return torch.stack(res)


def test_preprocess(images: list, transform) -> torch.Tensor:
    """Like ``_preprocess`` but converts RGBA images to RGB first.

    Args:
        images: Image file paths.
        transform: Callable applied to each opened PIL image; must return a
            tensor so the results can be stacked.

    Returns:
        Tensor produced by ``torch.stack`` over the transformed images.
    """
    res = []
    for img in images:
        with Image.open(img) as im:
            if im.mode == 'RGBA':
                im = im.convert('RGB')
            res.append(transform(im))
    return torch.stack(res)


def featurize(
        images: List[str],
        transform: Callable,
        net: nn.Module,
        device: torch.device,
        train: bool = False
) -> Dict[str, torch.Tensor]:
    """Run ``net`` on a batch of images and map each path to its feature.

    Args:
        images: Image file paths forming one batch.
        transform: Per-image transform passed to the preprocessing function.
        net: Model used to extract features.
        device: Target passed to ``.to()`` for both the data and the model.
        train: If True use ``_preprocess``; otherwise ``test_preprocess``
            (which also converts RGBA to RGB).

    Returns:
        Dict mapping each image path to the corresponding row of the model
        output.

    Raises:
        Exception: Any error during preprocessing or inference is logged and
            re-raised.
    """
    try:
        # Select appropriate preprocessing
        preprocess_fn = _preprocess if train else test_preprocess

        # Preprocess and move to device
        data = preprocess_fn(images, transform)
        data = data.to(device)
        net = net.to(device)

        # Inference only: no gradients. When configured, the input is cast to
        # float16 to match a model that was converted with .half().
        with torch.no_grad():
            if conf['models']['half']:
                data = data.half()
            features = net(data)

        # Create path-to-feature mapping
        return dict(zip(images, features))
    except Exception as e:
        print(f"Error in feature extraction: {e}")
        raise


def cosin_metric(x1, x2):
    """Return the cosine similarity of two 1-D vectors."""
    return np.dot(x1, x2) / (np.linalg.norm(x1) * np.linalg.norm(x2))


def threshold_search(y_score, y_true):
    """Find the score threshold that maximises accuracy.

    Every observed score is tried as a threshold; a sample is predicted
    positive when its score is ``>=`` the threshold.

    Args:
        y_score: Similarity scores.
        y_true: Ground-truth labels (1 for positive, 0 for negative).

    Returns:
        Tuple ``(best_accuracy, best_threshold)``; ``(0, 0)`` for empty input.
    """
    y_score = np.asarray(y_score)
    y_true = np.asarray(y_true)
    best_acc = 0
    best_th = 0
    for th in y_score:
        y_test = (y_score >= th)
        acc = np.mean((y_test == y_true).astype(int))
        if acc > best_acc:
            best_acc = acc
            best_th = th
    return best_acc, best_th


def showgrid(recall, recall_TN, PrecisePos, PreciseNeg, Correct):
    """Plot the metric curves over the threshold sweep and save 'grid.png'.

    Each argument is a sequence with one value per threshold in the sweep
    over [-1, 1].
    """
    x = np.linspace(start=-1, stop=1.0, num=_NUM_THRESHOLDS, endpoint=True).tolist()
    plt.figure(figsize=(10, 6))
    plt.plot(x, recall, color='red', label='recall:TP/TPFN')
    plt.plot(x, recall_TN, color='black', label='recall_TN:TN/TNFP')
    # Precision denominators are the *predicted* counts (TP+FP / TN+FN); the
    # legend previously repeated the recall formulas.
    plt.plot(x, PrecisePos, color='blue', label='PrecisePos:TP/TPFP')
    plt.plot(x, PreciseNeg, color='green', label='PreciseNeg:TN/TNFN')
    plt.plot(x, Correct, color='m', label='Correct:(TN+TP)/(TPFN+TNFP)')
    plt.legend()
    plt.xlabel('threshold')
    plt.grid(True, linestyle='--', alpha=0.5)
    plt.savefig('grid.png')
    plt.show()
    plt.close()


def showHist(same, cross):
    """Save histograms of same-class and cross-class scores to 'plot.png'.

    The figure is left open, so a later ``plt.show()`` also displays it.
    """
    Same = np.array(same)
    Cross = np.array(cross)

    fig, axs = plt.subplots(2, 1)
    axs[0].hist(Same, bins=100, edgecolor='black')
    axs[0].set_xlim([-1, 1])
    axs[0].set_title('Same Barcode')

    axs[1].hist(Cross, bins=100, edgecolor='black')
    axs[1].set_xlim([-1, 1])
    axs[1].set_title('Cross Barcode')
    plt.savefig('plot.png')


def compute_accuracy_recall(score, labels):
    """Sweep thresholds over [-1, 1], print metrics and draw the plots.

    Args:
        score: Similarity scores, one per sample.
        labels: Ground-truth labels aligned with ``score`` (1 = same,
            0 = cross).
    """
    score = np.asarray(score)
    labels = np.asarray(labels)
    thresholds = np.linspace(-1, 1, num=_NUM_THRESHOLDS)
    recall, PrecisePos, PreciseNeg, recall_TN, Correct = [], [], [], [], []

    t_labels = (labels == 1)
    f_labels = (labels == 0)
    # Split by label rather than by position so the histograms stay correct
    # regardless of the order of the samples.
    Same = score[t_labels]
    Cross = score[f_labels]

    for th in thresholds:
        t_score = (score > th)
        TP = np.sum(np.logical_and(t_score, t_labels))
        FN = np.sum(np.logical_and(np.logical_not(t_score), t_labels))

        f_score = (score < th)
        TN = np.sum(np.logical_and(f_score, f_labels))
        FP = np.sum(np.logical_and(np.logical_not(f_score), f_labels))

        total = TP + FP + TN + FN
        # Guard every ratio on its denominator; a zero denominator yields 0
        # instead of NaN.
        PrecisePos.append(0 if (TP + FP) == 0 else TP / (TP + FP))
        PreciseNeg.append(0 if TN == 0 else TN / (TN + FN))
        recall.append(0 if TP == 0 else TP / (TP + FN))
        recall_TN.append(0 if TN == 0 else TN / (TN + FP))
        # Accuracy counts correct negatives too, so it must not be zeroed
        # just because TP is 0.
        Correct.append(0 if total == 0 else (TP + TN) / total)
        print("Threshold:{} PrecisePos:{},recall:{},PreciseNeg:{},recall_TN:{}".format(th, PrecisePos[-1], recall[-1],
                                                                                       PreciseNeg[-1], recall_TN[-1]))
    showHist(Same, Cross)
    showgrid(recall, recall_TN, PrecisePos, PreciseNeg, Correct)


def compute_accuracy(
        feature_dict: Dict[str, torch.Tensor],
        cam: cal_cam,
) -> Tuple[float, float]:
    """Score every pair in the configured test list and find the best threshold.

    Args:
        feature_dict: Maps full image path to its feature tensor.
        cam: Heat-map helper forwarded to ``merge_imgs`` when
            ``conf['data']['save_image_joint']`` is set.

    Returns:
        Tuple ``(accuracy, threshold)`` from ``threshold_search``.

    Raises:
        IOError: If the pair list cannot be read.
        ValueError: If a pair references an image without a feature, or a
            label is not an integer.
    """
    try:
        pair_list = conf['data']['test_list']
        test_root = conf['data']['test_dir']

        with open(pair_list, 'r') as f:
            pairs = f.readlines()
    except IOError as e:
        print(f"Error reading pair list: {e}")
        raise

    similarities = []
    labels = []

    for pair in pairs:
        pair = pair.strip()
        if not pair:
            continue

        print(f"Processing pair: {pair}")
        try:
            img1, img2, label = pair.split()
        except ValueError:
            # unique_image() skips such lines, so no features exist for them;
            # skip here as well instead of crashing.
            print(f"Skipping malformed line: {pair}")
            continue

        img1_path = osp.join(test_root, img1)
        img2_path = osp.join(test_root, img2)

        # Verify features exist
        if img1_path not in feature_dict or img2_path not in feature_dict:
            raise ValueError(f"Missing features for image pair: {img1_path}, {img2_path}")

        # Get features and compute similarity
        feat1 = feature_dict[img1_path].cpu().numpy()
        feat2 = feature_dict[img2_path].cpu().numpy()
        similarity = cosin_metric(feat1, feat2)
        print('{} vs {}: {}'.format(img1_path, img2_path, similarity))

        if conf['data']['save_image_joint']:
            merge_imgs(img1_path, img2_path, conf, similarity, label, cam)

        similarities.append(similarity)
        labels.append(int(label))

    # Find optimal threshold and accuracy
    accuracy, threshold = threshold_search(similarities, labels)
    compute_accuracy_recall(np.array(similarities), np.array(labels))

    return accuracy, threshold


def deal_group_pair(pairList1, pairList2):
    """Return the maximum cosine similarity between two groups of features.

    Args:
        pairList1: Iterable of feature tensors of the first group.
        pairList2: Iterable of feature tensors of the second group.

    Returns:
        One-element list holding the largest similarity over all cross pairs.

    Raises:
        ValueError: If either group is empty.
    """
    # Convert once up front instead of once per cross pair.
    feats1 = [t.cpu().numpy() for t in pairList1]
    feats2 = [t.cpu().numpy() for t in pairList2]

    one_similarity = [cosin_metric(a, b) for a in feats1 for b in feats2]

    # Aggregation over the cross pairs: maximum. Mean or median are possible
    # alternatives.
    return [max(one_similarity)]


def compute_group_accuracy(content_list_read, net=None, transform=None):
    """Compute similarity scores and labels for group-vs-group test entries.

    Each entry is expected to hold two lists of image names (indices 0 and 1)
    and a label as the last element (1 / '1' = same, otherwise cross).

    Args:
        content_list_read: Entries loaded from the group-test JSON file.
        net: Model used for feature extraction. Defaults to the module-level
            ``model`` created when the file is run as a script.
        transform: Per-image test transform. Defaults to the second value
            returned by ``get_transform(conf)``.

    Returns:
        Tuple ``(similarities, labels)`` of equal-length lists; labels are
        ints. Entries that fail to process are reported and skipped.
    """
    if net is None:
        net = model  # module-level model defined under ``__main__``
    if transform is None:
        _, transform = get_transform(conf)

    # NOTE(review): the original read ``conf.test_val`` etc. as attributes,
    # which cannot work on the dict loaded from YAML. ``test_dir`` is assumed
    # to be the right image root for group tests - confirm.
    image_root = conf['data']['test_dir']
    batch_size = conf['data']['test_batch_size']
    target_device = conf['base']['device']

    allSimilarity, allLabel = [], []
    for data_loaded in content_list_read:
        print(data_loaded)
        try:
            one_group_list = []
            for i in range(2):
                images = [osp.join(image_root, img) for img in data_loaded[i]]
                group = group_image(images, batch_size)
                # Only the first batch of each group is used, as before.
                d = featurize(group[0], transform, net, target_device)
                one_group_list.append(list(d.values()))

            label = int(data_loaded[-1])
            similarity = [
                abs(s) for s in deal_group_pair(one_group_list[0], one_group_list[1])
            ]
        except Exception as e:
            # Best effort: report the failing entry instead of hiding it.
            print(f"Skipping group {data_loaded}: {e}")
            continue

        allSimilarity.extend(similarity)
        # Keep labels aligned one-to-one with the similarity scores.
        allLabel.extend([label] * len(similarity))
    return allSimilarity, allLabel


def init_model():
    """Build the configured backbone and load its weights.

    Returns:
        The model on ``conf['base']['device']``, converted to half precision
        when ``conf['models']['half']`` is set.

    Raises:
        ValueError: If the configured backbone is not supported.
    """
    backbone_name = conf['models']['backbone']
    target_device = conf['base']['device']
    model_path = conf['models']['model_path']

    tr_tools = trainer_tools(conf)
    backbone_mapping = tr_tools.get_backbone()
    if backbone_name not in backbone_mapping:
        # Pass the name itself; wrapping it in braces built a set literal.
        raise ValueError('不支持该模型: {}'.format(backbone_name))
    model = backbone_mapping[backbone_name]().to(target_device)
    print('load model {} '.format(backbone_name))

    state_dict = torch.load(model_path, map_location=target_device)

    if torch.cuda.device_count() > 1 and conf['base']['distributed']:
        model = nn.DataParallel(model).to(target_device)
        # Normal model loading. For checkpoints from contrastive-learning
        # training, keys may need "module.base_model." rewritten to "module."
        # and loading with strict=False.
        model.load_state_dict(state_dict)
    else:
        try:
            model.load_state_dict(state_dict)
        except RuntimeError:
            # Checkpoint was presumably saved from a DataParallel model:
            # strip the "module." prefix and retry non-strictly.
            stripped = {k.replace("module.", ""): v for k, v in state_dict.items()}
            result = model.load_state_dict(stripped, strict=False)
            if result.missing_keys or result.unexpected_keys:
                print("Warning: state_dict mismatch. missing: {} unexpected: {}".format(
                    result.missing_keys, result.unexpected_keys))

    if conf['models']['half']:
        model.half()
        first_param_dtype = next(model.parameters()).dtype
        print("模型的第一个参数的数据类型: {}".format(first_param_dtype))
    return model


if __name__ == '__main__':
    model = init_model()
    model.eval()
    cam = cal_cam(model, conf)
    _, test_transform = get_transform(conf)

    if not conf['data']['group_test']:
        images = unique_image(conf['data']['test_list'])
        images = [osp.join(conf['data']['test_dir'], img) for img in images]
        # Split the images into batches of test_batch_size.
        groups = group_image(images, conf['data']['test_batch_size'])

        feature_dict = dict()
        for group in groups:
            d = featurize(group, test_transform, model, conf['base']['device'])
            feature_dict.update(d)
        accuracy, threshold = compute_accuracy(feature_dict, cam)
        print(
            "Test Model: {} Accuracy: {} Threshold: {}".format(conf['models']['model_path'], accuracy, threshold)
        )
    else:
        filename = conf['data']['test_group_json']
        with open(filename, 'r', encoding='utf-8') as file:
            content_list_read = json.load(file)
        Similarity, Label = compute_group_accuracy(content_list_read, model, test_transform)
        compute_accuracy_recall(np.array(Similarity), np.array(Label))