访问 nn.Sequential 的类模块中的函数
Posted
技术标签:
【中文标题】访问 nn.Sequential 的类模块中的函数【英文标题】:Accessing functions in the class modules of nn.Sequential 【发布时间】:2021-02-12 19:21:20 【问题描述】:在运行 nn.Sequential 时,我包含一个类模块列表(这将是神经网络的层)。运行 nn.Sequential 时,它调用模块的转发函数。然而,每个类模块也有一个我想在 nn.Sequential 运行时访问的函数。运行 nn.Sequential 时如何访问和运行该函数?
【问题讨论】:
【参考方案1】:您可以为此使用 hook。让我们考虑以下在 VGG16 上演示的示例:
这是网络架构:
假设我们要监控 features Sequential
(您在上面看到的 Conv2d 层)中第 (2) 层的输入和输出。
为此,我们注册了一个名为 my_hook
的前向钩子,它将在任何前向传递中被调用:
import torch
from torchvision.models import vgg16
def my_hook(self, input, output):
print('my_hook\'s output')
print('input: ', input)
print('output: ', output)
# Sample net:
net = vgg16()
#Register forward hook:
net.features[2].register_forward_hook(my_hook)
# Test:
img = torch.randn(1,3,512,512)
out = net(img) # Will trigger my_hook and the data you are looking for will be printed
【讨论】:
以上是关于访问 nn.Sequential 的类模块中的函数的主要内容,如果未能解决你的问题,请参考以下文章
torch系列:torch中的nn.Sequential,nn.Concat/ConcatTable,nn.Parallel/PararelTable之间区别
PyTorch 中的 nn.functional() 与 nn.sequential() 之间是不是存在计算效率差异