import os
from collections import ChainMap

import torch
import torch.nn as nn
import yaml

from configs.utils import trainer_tools
from test_ori import group_image, featurize, cosin_metric
from tools.dataset import get_transform
from tools.getHeatMap import cal_cam
from tools.image_joint import merge_imgs


class SimilarAnalysis:
    """Pairwise image-similarity analysis driven by a YAML config.

    For every directory under ``conf['data']['data_dir']`` that has exactly two
    sub-directories, each image in the first sub-directory is compared with
    each image in the second. Pairs whose similarity exceeds a threshold are
    handed to ``merge_imgs`` together with a CAM object.
    """

    def __init__(self, conf_path='../configs/similar_analysis.yml'):
        """Load the config, build the model, the test transform and the CAM helper.

        Args:
            conf_path: Path to the YAML config. The default keeps the original
                hard-coded location, which is resolved relative to the current
                working directory.
        """
        # YAML streams are Unicode; be explicit so non-ASCII configs load the
        # same way regardless of the platform's default locale encoding.
        with open(conf_path, 'r', encoding='utf-8') as f:
            self.conf = yaml.load(f, Loader=yaml.FullLoader)
        self.model = self.initialize_model(self.conf)
        _, self.test_transform = get_transform(self.conf)
        self.cam = cal_cam(self.model, self.conf)

    def initialize_model(self, conf):
        """Initialize the model and load its weights.

        Args:
            conf: Parsed config. Reads ``conf['models']['backbone']``,
                ``conf['models']['model_path']`` and ``conf['base']['device']``.

        Returns:
            The model after ``model.eval()``.

        Raises:
            ValueError: If the configured backbone is not in the mapping
                returned by ``trainer_tools(conf).get_backbone()``.
        """
        tr_tools = trainer_tools(conf)
        backbone_mapping = tr_tools.get_backbone()
        backbone = conf['models']['backbone']
        model_path = conf['models']['model_path']
        print('model_path {}'.format(model_path))
        if backbone not in backbone_mapping:
            # Pass the name itself (not a one-element set) to the message.
            raise ValueError('不支持该模型: {}'.format(backbone))
        model = backbone_mapping[backbone]()

        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        # Load once and reuse the result for the fallback path below.
        state_dict = torch.load(model_path, map_location=conf['base']['device'])
        try:
            model.load_state_dict(state_dict)
        except RuntimeError:
            # load_state_dict raises RuntimeError on key mismatch. A common
            # cause is a checkpoint saved from a (Distributed)DataParallel
            # wrapper, whose keys start with "module."; strip only that
            # leading prefix so keys that merely contain "module." elsewhere
            # are left untouched.
            prefix = 'module.'
            new_state_dict = {
                (k[len(prefix):] if k.startswith(prefix) else k): v
                for k, v in state_dict.items()
            }
            result = model.load_state_dict(new_state_dict, strict=False)
            # strict=False would otherwise hide a partial load.
            if result.missing_keys or result.unexpected_keys:
                print('Warning: partial state_dict load; missing keys: {}, '
                      'unexpected keys: {}'.format(result.missing_keys,
                                                   result.unexpected_keys))
        return model.eval()

    def get_feature(self, img_pth):
        """Return the output of ``featurize`` for the single image ``img_pth``.

        Presumably a mapping from image path to feature tensor (that is how
        ``get_contrast_result`` indexes it) -- verify against ``test_ori.featurize``.
        """
        group = group_image([img_pth], self.conf['data']['val_batch_size'])
        feature = featurize(group[0], self.test_transform, self.model, self.conf['base']['device'])
        return feature

    def get_similarity(self, feature_dict1, feature_dict2):
        """Print and return ``cosin_metric(feature_dict1, feature_dict2)``."""
        similarity = cosin_metric(feature_dict1, feature_dict2)
        print(f"Similarity: {similarity}")
        return similarity

    def get_feature_map(self, all_imgs):
        """Featurize every path in ``all_imgs`` and merge the results into one dict.

        On key collisions the later image's entry wins.
        """
        feature_dicts = {}
        for img_pth in all_imgs:
            print(f"Processing {img_pth}")
            # In-place update: same precedence and key order as the former
            # dict(ChainMap(new, old)), without rebuilding the dict each time.
            feature_dicts.update(self.get_feature(img_pth))
        return feature_dicts

    def _iter_pair_dirs(self):
        """Yield both sub-directory paths of every directory under data_dir that has exactly two."""
        for root, dirs, _files in os.walk(self.conf['data']['data_dir']):
            if len(dirs) == 2:
                yield os.path.join(root, dirs[0]), os.path.join(root, dirs[1])

    @staticmethod
    def _list_images(dir_pth):
        """Return the sorted paths of the regular files directly inside ``dir_pth``."""
        return [
            os.path.join(dir_pth, name)
            for name in sorted(os.listdir(dir_pth))
            if os.path.isfile(os.path.join(dir_pth, name))
        ]

    def get_image_map(self):
        """Return every (image from sub-dir 1, image from sub-dir 2) pair to compare."""
        all_compare_img = []
        for dir_pth_1, dir_pth_2 in self._iter_pair_dirs():
            # List the second directory once instead of once per image of the first.
            imgs_2 = self._list_images(dir_pth_2)
            all_compare_img.extend(
                (img_1, img_2)
                for img_1 in self._list_images(dir_pth_1)
                for img_2 in imgs_2
            )
        return all_compare_img

    def create_total_feature(self):
        """Return every image path that appears in some comparison pair.

        Uses the same directory helpers as ``get_image_map``, so the paths here
        are exactly the keys ``get_contrast_result`` will look up.
        """
        all_imgs = []
        for pair in self._iter_pair_dirs():
            for dir_pth in pair:
                all_imgs.extend(self._list_images(dir_pth))
        return all_imgs

    def get_contrast_result(self, feature_dicts, all_compare_img, threshold=0.7):
        """Compare each image pair and merge those whose similarity exceeds ``threshold``.

        Args:
            feature_dicts: Mapping from image path to a feature whose
                ``.cpu().numpy()`` is passed to ``get_similarity``.
            all_compare_img: Iterable of ``(img_pth1, img_pth2)`` pairs.
            threshold: Pairs with similarity strictly greater than this are
                merged. Defaults to the original hard-coded 0.7.
        """
        for img_pth1, img_pth2 in all_compare_img:
            feature1 = feature_dicts[img_pth1].cpu().numpy()
            feature2 = feature_dicts[img_pth2].cpu().numpy()
            similarity = self.get_similarity(feature1, feature2)
            if similarity > threshold:
                # Name of the directory holding the two compared sub-dirs
                # (grandparent of the image). os.path handles the platform
                # separator and redundant slashes, unlike split('/').
                dir_name = os.path.basename(os.path.dirname(os.path.dirname(img_pth1)))
                save_path = os.path.join(self.conf['data']['image_joint_pth'], dir_name)
                merge_imgs(img_pth1, img_pth2, self.conf, similarity, label=None, cam=self.cam, save_path=save_path)


if __name__ == '__main__':
    ana = SimilarAnalysis()
    all_imgs = ana.create_total_feature()
    feature_dicts = ana.get_feature_map(all_imgs)
    all_compare_img = ana.get_image_map()
    ana.get_contrast_result(feature_dicts, all_compare_img)