# -*- coding: UTF-8 -*-
"""Grad-CAM style heatmap generation for a sequential-style PyTorch model.

The model's direct children are run one after another; the output of the
child named by ``conf['heatmap']['feature_layer']`` is used as the feature
map whose gradients weight the class activation map (CAM).
"""
import os
import torch
from torchvision import models
import torch.nn as nn
import torchvision.transforms as tfs
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import cv2
# from tools.config import cfg
# from comparative.tools.initmodel import initSimilarityModel
import yaml
from dataset import get_transform


class cal_cam(nn.Module):
    """Compute class activation maps and render them over input images.

    Args:
        model: Network whose direct children are applied in registration
            order. A child named ``"avgpool"`` has its output flattened to
            ``(batch, -1)`` before the next child runs.
        conf: Configuration mapping. This class reads
            ``conf['base']['device']`` and ``conf['heatmap']['feature_layer']``
            and passes the whole mapping to ``get_transform``.
    """

    # (width, height) that images are resized to and CAMs are upsampled to.
    IMG_SIZE = (224, 224)

    def __init__(self, model, conf):
        super(cal_cam, self).__init__()
        self.conf = conf
        self.device = self.conf['base']['device']
        self.model = model
        self.model.to(self.device)
        # Name of the child module whose output gradients are captured.
        self.feature_layer = conf['heatmap']['feature_layer']
        # Gradients captured by the backward hook (most recent last).
        self.gradient = []
        # Captured feature maps, each wrapped in a one-element list.
        self.output = []
        # get_transform returns a pair; only the second element is used here.
        _, self.transform = get_transform(self.conf)

    def get_conf(self, yaml_pth):
        """Load and return the YAML configuration stored at ``yaml_pth``."""
        with open(yaml_pth, 'r') as f:
            # NOTE(review): FullLoader can construct more than plain data
            # types; prefer yaml.safe_load if configs may be untrusted.
            conf = yaml.load(f, Loader=yaml.FullLoader)
        return conf

    def save_grad(self, grad):
        """Backward hook: remember the gradient of the feature map."""
        self.gradient.append(grad)

    def get_grad(self):
        """Return the most recently captured gradient as a CPU tensor."""
        return self.gradient[-1].cpu().data

    def get_feature(self):
        """Return the most recently captured feature map."""
        return self.output[-1][0]

    def process_img(self, input):
        """Apply the configured transform and add a leading batch axis."""
        tensor = self.transform(input)
        return tensor.unsqueeze(0)

    def getGrad(self, input_):
        """Run a forward/backward pass and capture the feature-layer gradient.

        The backward pass starts from the highest-scoring output of each
        sample, so the captured gradient is that of the predicted class.

        Args:
            input_: Input batch tensor; it is moved to ``self.device``.

        Returns:
            Tuple ``(grad_val, feature, input_grad)``: the gradient at the
            feature layer (CPU), the feature-layer output, and the gradient
            with respect to the input tensor.

        Raises:
            ValueError: If the model has no direct child named
                ``self.feature_layer``.
        """
        # Drop anything captured by a previous call.
        self.gradient = []
        self.output = []

        input_ = input_.to(self.device).requires_grad_(True)

        activ = input_
        for name, module in self.model._modules.items():
            activ = module(activ)
            if name == self.feature_layer:
                # Capture both the activation and (via hook) its gradient.
                activ.register_hook(self.save_grad)
                self.output.append([activ])
            elif name == "avgpool":
                # Flatten before the fully connected layer that follows.
                activ = activ.reshape(activ.shape[0], -1)

        if not self.output:
            raise ValueError(
                f"feature layer {self.feature_layer!r} is not a direct "
                f"child of the model"
            )

        # ``activ`` is now the final layer's output. Back-propagate from the
        # top score; the gradient is 1 at the arg-max entry and 0 elsewhere.
        top_score = activ.max(dim=-1).values.sum()

        # Clear stale parameter gradients before the backward pass.
        self.model.zero_grad()
        top_score.backward()

        grad_val = self.get_grad()
        feature = self.get_feature()
        return grad_val, feature, input_.grad

    def getCam(self, grad_val, feature, size=(224, 224)):
        """Build the class activation map for the first sample of a batch.

        Args:
            grad_val: Gradient at the feature layer; averaged over axes 2
                and 3 to obtain one weight per channel.
            feature: Feature-layer output matching ``grad_val``.
            size: ``(width, height)`` passed to ``cv2.resize`` for the
                upsampled map.

        Returns:
            Tuple ``(cam, cam_)``: the ReLU-ed map at feature resolution as
            a tensor, and the resized map as a NumPy array scaled to
            ``[0, 1]`` (all zeros if the map has no positive response).
        """
        # Global-average-pool the gradients: one weight per channel.
        weights = torch.mean(grad_val, dim=(2, 3))[0].detach().cpu().float()
        fmap = feature[0].detach().cpu().float()

        # Weighted sum of the channels, followed by ReLU.
        weighted = (weights[:, None, None] * fmap).sum(dim=0)
        cam = np.maximum(weighted.numpy(), 0)

        # Upsample to the target size and normalise.
        cam_ = cv2.resize(cam, size)
        cam_ = cam_ - np.min(cam_)
        peak = np.max(cam_)
        if peak > 0:
            # Guard: a flat map would otherwise divide by zero and yield NaN.
            cam_ = cam_ / peak

        cam = torch.from_numpy(cam)
        return cam, cam_

    @staticmethod
    def _colorize(cam_):
        """Map a ``[0, 1]`` CAM to a JET heatmap (uint8, OpenCV BGR order)."""
        return cv2.applyColorMap(np.uint8(255 * cam_), cv2.COLORMAP_JET)

    def show_img(self, cam_, img, heatmap_save_pth, imgname):
        """Blend the heatmap over ``img`` and write it to disk.

        Args:
            cam_: Normalised CAM with the same height/width as ``img``.
            img: PIL image (converted to RGB here). Any other array-like is
                assumed to already be in OpenCV's BGR channel order.
            heatmap_save_pth: Output directory; created if missing.
            imgname: Output file name inside ``heatmap_save_pth``.

        Raises:
            OSError: If OpenCV reports that the file was not written.
        """
        heatmap_bgr = self._colorize(cam_)
        if isinstance(img, Image.Image):
            # PIL is RGB; flip to BGR so it matches the heatmap and imwrite.
            base_bgr = np.asarray(img.convert("RGB"), dtype=np.float32)[:, :, ::-1]
        else:
            base_bgr = np.float32(img)
        cam_img = np.uint8(0.3 * heatmap_bgr + 0.7 * base_bgr)

        os.makedirs(heatmap_save_pth, exist_ok=True)
        out_path = os.path.join(heatmap_save_pth, imgname)
        # cv2.imwrite signals failure through its return value only.
        if not cv2.imwrite(out_path, cam_img):
            raise OSError(f"failed to write heatmap to {out_path}")

    def get_hot_map(self, img_pth):
        """Return the heatmap overlay for the image at ``img_pth``.

        Returns:
            RGB PIL image of size ``IMG_SIZE`` with the CAM blended in.
        """
        # Force RGB so grayscale / RGBA files blend with the 3-channel map.
        img = Image.open(img_pth).convert("RGB").resize(self.IMG_SIZE)
        batch = self.process_img(img)
        grad_val, feature, _ = self.getGrad(batch)
        _, cam_ = self.getCam(grad_val, feature, size=self.IMG_SIZE)

        # OpenCV colour maps are BGR; convert before mixing with PIL's RGB.
        heatmap_rgb = cv2.cvtColor(self._colorize(cam_), cv2.COLOR_BGR2RGB)
        cam_img = 0.3 * heatmap_rgb + 0.7 * np.float32(img)
        return Image.fromarray(np.uint8(cam_img))

    def forward(self, img_root, heatmap_save_pth):
        """Write a heatmap overlay for every file in ``img_root``.

        Args:
            img_root: Directory containing the input images.
            heatmap_save_pth: Directory the overlays are written to, using
                the input file names.

        Returns:
            The feature-resolution CAM of the last processed image, or
            ``None`` when ``img_root`` is empty.
        """
        cam = None
        for imgname in sorted(os.listdir(img_root)):
            img_file = os.path.join(img_root, imgname)
            img = Image.open(img_file).convert("RGB").resize(self.IMG_SIZE)
            batch = self.process_img(img)
            grad_val, feature, _ = self.getGrad(batch)
            cam, cam_ = self.getCam(grad_val, feature, size=self.IMG_SIZE)
            self.show_img(cam_, img, heatmap_save_pth, imgname)
        return cam


if __name__ == "__main__":
    # TODO(owner): supply the network and its configuration here. cal_cam
    # requires both, and this file defines no way to build either.
    model = None
    conf = None
    if model is None or conf is None:
        raise SystemExit(
            "cal_cam needs a model and a config: set `model` and `conf` "
            "in the __main__ block before running this script."
        )

    cam = cal_cam(model, conf)
    img_root = "test_img/"
    heatmap_save_pth = "heatmap_result"
    cam(img_root, heatmap_save_pth)