Pytorch可视化热力图
Posted 谷小雨
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Pytorch可视化热力图相关的知识,希望对你有一定的参考价值。
可视化热力图可以有两种方式:
1)特征图可视化,将各通道特征的最大值作为热力图像素值,进行可视化——可以参考博客,一种比较灵活的特征图保存方式
2)根据梯度值结合特征图计算热力图,热力图的显示的重点是梯度高的地方,也是网络关注的重点
基于梯度进行热力图可视化有一些工作,如grad-cam,也有一些开发好的脚本,不过这些脚本不具有通用性,
因此此处基于torch的hook机制进行可视化,是一种基础并且通用性很好的策略,很容易在自己的模式上进行尝试。
代码的逻辑结构如下:
首先定义模型,加载权重,再对想要可视化的网络层进行hook注册,接下来推理模型并进行梯度反传即可,
farward_hook,backward_hook会自动获取对应的特征图和反传梯度,后面处理并保存到本地
# coding: utf-8 import cv2 import os import torch import numpy as np def img_preprocess(img_in): pass # get activate map def backward_hook(module, grad_in, grad_out): grad_block.append(grad_out[0].detach()) # get gradient map def farward_hook(module, input, output): fmap_block.append(output) # apply color to heatmap and save the result def show_cam_on_image(img, mask, out_dir): h, w, _ = img.shape heatmap = cv2.resize(heatmap, (w, h)) heatmap = cv2.applyColorMap(np.uint8(255*mask), cv2.COLORMAP_JET) heatmap = np.float32(heatmap) / 255 img = np.float32(img) / 255 # make sure pixel value will not be bigger than 256 after add cam = heatmap + np.float32(img) # show heatmap in original image cam = cam / np.max(cam) path_cam_img = os.path.join(out_dir, "cam.jpg") path_raw_img = os.path.join(out_dir, "raw.jpg") if not os.path.exists(out_dir): os.makedirs(out_dir) cv2.imwrite(path_cam_img, np.uint8(255 * cam)) cv2.imwrite(path_raw_img, np.uint8(255 * img)) def gen_cam(feature_map, grads): """ 依据梯度和特征图,生成cam :param feature_map: np.array, in [C, H, W] :param grads: np.array, in [C, H, W] :return: np.array, [H, W] """ cam = np.zeros(feature_map.shape[1:], dtype=np.float32) # cam shape (H, W) weights = np.mean(grads, axis=(1, 2)) # for i, w in enumerate(weights): cam += w * feature_map[i, :, :] cam = np.maximum(cam, 0) cam = cv2.resize(cam, (32, 32)) cam -= np.min(cam) cam /= np.max(cam) return cam if __name__ == \'__main__\': BASE_DIR = os.path.dirname(os.path.abspath(__file__)) path_img, path_net, output_dir = None, None, None # change to yours fmap_block = list() grad_block = list() # 图片读取;网络加载 img = cv2.imread(path_img, 1) # H*W*C img_input = img_preprocess(img) model = ResNet50() model.load_state_dict(torch.load(path_net)) # 注册 hook model.layer4.register_forward_hook(farward_hook) # get activate map model.layer4.register_backward_hook(backward_hook) # get gradient map # forward output = model(img_input) # model.training is True # backward model.zero_grad() loss = model.get_loss(output) loss.backward() # 生成cam grads_val = grad_block[0].cpu().data.numpy().squeeze() fmap = fmap_block[0].cpu().data.numpy().squeeze() cam = gen_cam(fmap, grads_val) # 保存cam图片 show_cam_on_image(img, cam, output_dir)
——主体参考代码
R语言ggplot2可视化:使用热力图可视化dataframe数据自定义设置热力图的颜色自定添加标题轴标签热力图线框等
R语言ggplot2可视化:使用热力图可视化dataframe数据、自定义设置热力图的颜色、自定添加标题、轴标签、热力图线框等
目录
以上是关于Pytorch可视化热力图的主要内容,如果未能解决你的问题,请参考以下文章
R语言ggplot2可视化:使用热力图可视化dataframe数据自定义设置热力图的颜色自定添加标题轴标签热力图线框等
数据可视化Python 热力图(seaborn.heatmap)
Python使用matplotlib可视化相关性分析热力图图heatmap使用seaborn中的heatmap函数可视化相关性热力图(Correllogram)
Python使用matplotlib可视化时间序列日历热力图日历热力图可以很好地描绘极端值和节日数据特性(Calendar Heatmap)