训练数据前置处理与提升训练效率
This commit is contained in:
164
tools/getHeatMap.py
Normal file
164
tools/getHeatMap.py
Normal file
@ -0,0 +1,164 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
import os
|
||||
|
||||
import torch
|
||||
from torchvision import models
|
||||
import torch.nn as nn
|
||||
import torchvision.transforms as tfs
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from PIL import Image
|
||||
import cv2
|
||||
# from tools.config import cfg
|
||||
# from comparative.tools.initmodel import initSimilarityModel
|
||||
import yaml
|
||||
from dataset import get_transform
|
||||
|
||||
|
||||
class cal_cam(nn.Module):
|
||||
def __init__(self, model, conf):
|
||||
super(cal_cam, self).__init__()
|
||||
self.conf = conf
|
||||
self.device = self.conf['base']['device']
|
||||
|
||||
self.model = model
|
||||
self.model.to(self.device)
|
||||
|
||||
# 要求梯度的层
|
||||
self.feature_layer = conf['heatmap']['feature_layer']
|
||||
# 记录梯度
|
||||
self.gradient = []
|
||||
# 记录输出的特征图
|
||||
self.output = []
|
||||
_, self.transform = get_transform(self.conf)
|
||||
|
||||
def get_conf(self, yaml_pth):
|
||||
with open(yaml_pth, 'r') as f:
|
||||
conf = yaml.load(f, Loader=yaml.FullLoader)
|
||||
return conf
|
||||
|
||||
def save_grad(self, grad):
|
||||
self.gradient.append(grad)
|
||||
|
||||
def get_grad(self):
|
||||
return self.gradient[-1].cpu().data
|
||||
|
||||
def get_feature(self):
|
||||
return self.output[-1][0]
|
||||
|
||||
def process_img(self, input):
|
||||
input = self.transform(input)
|
||||
input = input.unsqueeze(0)
|
||||
return input
|
||||
|
||||
# 计算最后一个卷积层的梯度,输出梯度和最后一个卷积层的特征图
|
||||
def getGrad(self, input_):
|
||||
self.gradient = [] # 清除之前的梯度
|
||||
self.output = [] # 清除之前的特征图
|
||||
# print(f"cuda.memory_allocated 1 {torch.cuda.memory_allocated()/ (1024 ** 3)}G")
|
||||
input_ = input_.to(self.device).requires_grad_(True)
|
||||
num = 1
|
||||
for name, module in self.model._modules.items():
|
||||
# print(f'module_name: {name}')
|
||||
# print(f'module: {module}')
|
||||
if (num == 1):
|
||||
input = module(input_)
|
||||
num = num + 1
|
||||
continue
|
||||
# 是待提取特征图的层
|
||||
if (name == self.feature_layer):
|
||||
input = module(input)
|
||||
input.register_hook(self.save_grad)
|
||||
self.output.append([input])
|
||||
# 马上要到全连接层了
|
||||
elif (name == "avgpool"):
|
||||
input = module(input)
|
||||
input = input.reshape(input.shape[0], -1)
|
||||
# 普通的层
|
||||
else:
|
||||
input = module(input)
|
||||
|
||||
# print(f"cuda.memory_allocated 2 {torch.cuda.memory_allocated() / (1024 ** 3)}G")
|
||||
# 到这里input就是最后全连接层的输出了
|
||||
index = torch.max(input, dim=-1)[1]
|
||||
one_hot = torch.zeros((1, input.shape[-1]), dtype=torch.float32)
|
||||
one_hot[0][index] = 1
|
||||
confidenct = one_hot * input.cpu()
|
||||
confidenct = torch.sum(confidenct, dim=-1).requires_grad_(True)
|
||||
|
||||
# print(f"cuda.memory_allocated 3 {torch.cuda.memory_allocated() / (1024 ** 3)}G")
|
||||
# 清除之前的所有梯度
|
||||
self.model.zero_grad()
|
||||
# 反向传播获取梯度
|
||||
grad_output = torch.ones_like(confidenct)
|
||||
confidenct.backward(grad_output)
|
||||
# 获取特征图的梯度
|
||||
grad_val = self.get_grad()
|
||||
feature = self.get_feature()
|
||||
|
||||
# print(f"cuda.memory_allocated 4 {torch.cuda.memory_allocated() / (1024 ** 3)}G")
|
||||
return grad_val, feature, input_.grad
|
||||
|
||||
# 计算CAM
|
||||
def getCam(self, grad_val, feature):
|
||||
# 对特征图的每个通道进行全局池化
|
||||
alpha = torch.mean(grad_val, dim=(2, 3)).cpu()
|
||||
feature = feature.cpu()
|
||||
# 将池化后的结果和相应通道特征图相乘
|
||||
cam = torch.zeros((feature.shape[2], feature.shape[3]), dtype=torch.float32)
|
||||
for idx in range(alpha.shape[1]):
|
||||
cam = cam + alpha[0][idx] * feature[0][idx]
|
||||
# 进行ReLU操作
|
||||
cam = np.maximum(cam.detach().numpy(), 0)
|
||||
|
||||
# plt.imshow(cam)
|
||||
# plt.colorbar()
|
||||
# plt.savefig("cam.jpg")
|
||||
|
||||
# 将cam区域放大到输入图片大小
|
||||
cam_ = cv2.resize(cam, (224, 224))
|
||||
cam_ = cam_ - np.min(cam_)
|
||||
cam_ = cam_ / np.max(cam_)
|
||||
# plt.imshow(cam_)
|
||||
# plt.savefig("cam_.jpg")
|
||||
cam = torch.from_numpy(cam)
|
||||
|
||||
return cam, cam_
|
||||
|
||||
def show_img(self, cam_, img, heatmap_save_pth, imgname):
|
||||
heatmap = cv2.applyColorMap(np.uint8(255 * cam_), cv2.COLORMAP_JET)
|
||||
cam_img = 0.3 * heatmap + 0.7 * np.float32(img)
|
||||
# cv2.imwrite("img.jpg", cam_img)
|
||||
cv2.imwrite(os.sep.join([heatmap_save_pth, imgname]), cam_img)
|
||||
|
||||
def get_hot_map(self, img_pth):
|
||||
img = Image.open(img_pth)
|
||||
img = img.resize((224, 224))
|
||||
input = self.process_img(img)
|
||||
grad_val, feature, input_grad = self.getGrad(input)
|
||||
cam, cam_ = self.getCam(grad_val, feature)
|
||||
heatmap = cv2.applyColorMap(np.uint8(255 * cam_), cv2.COLORMAP_JET)
|
||||
cam_img = 0.3 * heatmap + 0.7 * np.float32(img)
|
||||
cam_img = Image.fromarray(np.uint8(cam_img))
|
||||
return cam_img
|
||||
|
||||
# def __call__(self, img_root, heatmap_save_pth):
|
||||
# for imgname in os.listdir(img_root):
|
||||
# img = Image.open(os.sep.join([img_root, imgname]))
|
||||
# img = img.resize((224, 224))
|
||||
# # plt.imshow(img)
|
||||
# # plt.savefig("airplane.jpg")
|
||||
# input = self.process_img(img)
|
||||
# grad_val, feature, input_grad = self.getGrad(input)
|
||||
# cam, cam_ = self.getCam(grad_val, feature)
|
||||
# self.show_img(cam_, img, heatmap_save_pth, imgname)
|
||||
# return cam
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cam = cal_cam()
|
||||
img_root = "test_img/"
|
||||
heatmap_save_pth = "heatmap_result"
|
||||
cam(img_root, heatmap_save_pth)
|
Reference in New Issue
Block a user