访问 nn.Sequential 的类模块中的函数

Posted

技术标签:

【中文标题】访问 nn.Sequential 的类模块中的函数【英文标题】:Accessing functions in the class modules of nn.Sequential 【发布时间】:2021-02-12 19:21:20 【问题描述】:

在运行 nn.Sequential 时,我包含一个类模块列表(这将是神经网络的层)。运行 nn.Sequential 时,它调用模块的转发函数。然而,每个类模块也有一个我想在 nn.Sequential 运行时访问的函数。运行 nn.Sequential 时如何访问和运行该函数?

【问题讨论】:

【参考方案1】:

您可以为此使用 hook。让我们考虑以下在 VGG16 上演示的示例:

这是网络架构:

假设我们要监控 features Sequential(您在上面看到的 Conv2d 层)中第 (2) 层的输入和输出。 为此,我们注册了一个名为 my_hook 的前向钩子,它将在任何前向传递中被调用:

import torch
from torchvision.models import vgg16

def my_hook(self, input, output):
    print('my_hook\'s output')
    print('input: ', input)
    print('output: ', output)

# Sample net:
net = vgg16()

#Register forward hook:
net.features[2].register_forward_hook(my_hook)

# Test:
img = torch.randn(1,3,512,512)
out = net(img) # Will trigger my_hook and the data you are looking for will be printed

【讨论】:

以上是关于访问 nn.Sequential 的类模块中的函数的主要内容,如果未能解决你的问题,请参考以下文章

torch系列:torch中的nn.Sequential,nn.Concat/ConcatTable,nn.Parallel/PararelTable之间区别

深度学习之构造模型,访问模型参数——2020.3.11

一些函数

PyTorch 中的 nn.functional() 与 nn.sequential() 之间是不是存在计算效率差异

pytorch教程之nn.Sequential类详解——使用Sequential类来自定义顺序连接模型

pytorch教程之nn.Sequential类详解——使用Sequential类来自定义顺序连接模型