Pytorch框架学习---hook函数和CAM类激活图

Posted zpc1001

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Pytorch框架学习---hook函数和CAM类激活图相关的知识,希望对你有一定的参考价值。

本节简单总结Pytorch中hook函数,CAM算法生成注意力图【文中思维导图采用MindMaster软件
注意:对于真正运用CAM的代码,本人后续随着需要,再逐步更新。

1.hook函数

(1)定义

? 不改变主体(前向、后向传播等)情况下,实现额外的功能,如在backward之后,仍然可以得到特征图和非叶子节点的梯度,即便它们被释放。

(2)方法

技术图片

? 节省精力, 由于网上已经有人对这4和hook函数总结的很好,故在此引用,不再复写。

? 这里我们直接来举一个例子,使用hook函数可视化所有层的特征图,即调用上面的register_forward_hook获取网络层的输出:

# 注册hook
    fmap_dict = dict()
    for name, sub_module in alexnet.named_modules():  # 如果是named_children()则是返回Sequential本身features
        # print(sub_module)   # sub_module  Sequential本身features以及内部所有的网络层features.0

        if isinstance(sub_module, nn.Conv2d):
            key_name = str(sub_module.weight.shape)
            fmap_dict.setdefault(key_name, list())   # 构建字典中key value对

            n1, n2 = name.split(".")  # features.0,  为nn.Sequential

            def hook_func(module, i, o):
                key_name = str(module.weight.shape)
                fmap_dict[key_name].append(o)  # 索引名字,添加特征图
                # print("famp_dict:{}".format(fmap_dict))

            alexnet._modules[n1]._modules[n2].register_forward_hook(hook_func)

    # forward
    output = alexnet(img_tensor)

    # add image
    for layer_name, fmap_list in fmap_dict.items():  # 返回一个可迭代的列表
        fmap = fmap_list[0]  # 把list中元素取出
        fmap.transpose_(0, 1)

        nrow = int(np.sqrt(fmap.shape[0]))
        fmap_grid = vutils.make_grid(fmap, normalize=True, scale_each=True, nrow=nrow)
        writer.add_image(‘feature map in {}‘.format(layer_name), fmap_grid, global_step=0)

? 对每一个卷积层得到的特征图,作tensorboard可视化:

技术图片

? 注意:这里可视化卷积层,但是由于卷积层后面接的是激活函数relu,其中relu(inplace=True)原位操作,会对卷积层的输出做一定的改变。

2.CAM(Class Activation Map)类激活图

? 啥话先不说,直接上图!!!原来这个就是CAM算法出来的,当判别网络将图片归为“猫”这个类别时,红色代表网络注意的地方,蓝色则是没有注意的地方:

技术图片

(1)原始CAM

技术图片

? 最后一层卷积得到的特征图,经过全局平均池化GAP,得到对应神经元向量,全连接层的权重,即是CAM对特征图加权的权重,经过加权之后的特征图即是最终类似注意力的激活图。

? 局限性:最后必须是GAP,需要改动原始网络并重新训练,因而改进版Grad-CAM上线

(2)Grad-CAM(利用特征图的梯度,作为加权权重)

? 对特征图梯度做平均,得到n个特征图对应的n个平均梯度,将其作为CAM权重。

技术图片

? 实战代码如下参考:github,后续用到CAM时,再放入自己项目的激活图展示代码。


以上是关于Pytorch框架学习---hook函数和CAM类激活图的主要内容,如果未能解决你的问题,请参考以下文章

学习打卡05可解释机器学习笔记之CAM+Captum代码实战

学习打卡05可解释机器学习笔记之CAM+Captum代码实战

React函数类组件及其Hooks学习

Pytorch学习--编程实战:猫和狗二分类

CAM(类激活映射),卷积可视化,神经网络可视化,一个库搞定,真的简单的不能再简单

pytorch入门 01