Pytorch框架学习---hook函数和CAM类激活图
Posted zpc1001
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Pytorch框架学习---hook函数和CAM类激活图相关的知识,希望对你有一定的参考价值。
本节简单总结Pytorch中hook函数,CAM算法生成注意力图【文中思维导图采用MindMaster软件】 |
1.hook函数
(1)定义
? 不改变主体(前向、后向传播等)情况下,实现额外的功能,如在backward之后,仍然可以得到特征图和非叶子节点的梯度,即便它们被释放。
(2)方法
? 节省精力, 由于网上已经有人对这4和hook函数总结的很好,故在此引用,不再复写。
? 这里我们直接来举一个例子,使用hook函数可视化所有层的特征图,即调用上面的register_forward_hook获取网络层的输出:
# 注册hook
fmap_dict = dict()
for name, sub_module in alexnet.named_modules(): # 如果是named_children()则是返回Sequential本身features
# print(sub_module) # sub_module Sequential本身features以及内部所有的网络层features.0
if isinstance(sub_module, nn.Conv2d):
key_name = str(sub_module.weight.shape)
fmap_dict.setdefault(key_name, list()) # 构建字典中key value对
n1, n2 = name.split(".") # features.0, 为nn.Sequential
def hook_func(module, i, o):
key_name = str(module.weight.shape)
fmap_dict[key_name].append(o) # 索引名字,添加特征图
# print("famp_dict:{}".format(fmap_dict))
alexnet._modules[n1]._modules[n2].register_forward_hook(hook_func)
# forward
output = alexnet(img_tensor)
# add image
for layer_name, fmap_list in fmap_dict.items(): # 返回一个可迭代的列表
fmap = fmap_list[0] # 把list中元素取出
fmap.transpose_(0, 1)
nrow = int(np.sqrt(fmap.shape[0]))
fmap_grid = vutils.make_grid(fmap, normalize=True, scale_each=True, nrow=nrow)
writer.add_image(‘feature map in {}‘.format(layer_name), fmap_grid, global_step=0)
? 对每一个卷积层得到的特征图,作tensorboard可视化:
? 注意:这里可视化卷积层,但是由于卷积层后面接的是激活函数relu,其中relu(inplace=True)原位操作,会对卷积层的输出做一定的改变。
2.CAM(Class Activation Map)类激活图
? 啥话先不说,直接上图!!!原来这个就是CAM算法出来的,当判别网络将图片归为“猫”这个类别时,红色代表网络注意的地方,蓝色则是没有注意的地方:
(1)原始CAM
? 最后一层卷积得到的特征图,经过全局平均池化GAP,得到对应神经元向量,全连接层的权重,即是CAM对特征图加权的权重,经过加权之后的特征图即是最终类似注意力的激活图。
? 局限性:最后必须是GAP,需要改动原始网络并重新训练,因而改进版Grad-CAM上线。
(2)Grad-CAM(利用特征图的梯度,作为加权权重)
? 对特征图梯度做平均,得到n个特征图对应的n个平均梯度,将其作为CAM权重。
? 实战代码如下参考:github,后续用到CAM时,再放入自己项目的激活图展示代码。
以上是关于Pytorch框架学习---hook函数和CAM类激活图的主要内容,如果未能解决你的问题,请参考以下文章
学习打卡05可解释机器学习笔记之CAM+Captum代码实战
学习打卡05可解释机器学习笔记之CAM+Captum代码实战