如何在pytorch中为nn.Transformer编写一个前向钩子函数?

Posted

技术标签:

【中文标题】如何在pytorch中为nn.Transformer编写一个前向钩子函数?【英文标题】:How to write a forward hook function for nn.Transformer in pytorch? 【发布时间】:2021-11-04 03:12:42 【问题描述】:

我了解到前向钩子函数的形式为hook_fn(m,x,y)。 m 指模型,x 指输入,y 指输出。我想为nn.Transformer 写一个前向钩子函数。 但是,变压器层需要输入 src 和 tgt。例如,>>> out = transformer_model(src, tgt)。那么如何区分这些输入呢?

【问题讨论】:

请提供足够的代码,以便其他人更好地理解或重现问题。 【参考方案1】:

你的钩子会用 tuples 为xy 调用你的回调函数。正如torch.nn.Module.register_forward_hook 的文档页面中所述(它确实很好地解释了xy 的类型)。

输入只包含给模块的位置参数。 关键字参数不会传递给钩子,而只会传递给 向前。 [...]。

model = nn.Transformer(nhead=16, num_encoder_layers=12)
src = torch.rand(10, 32, 512)
tgt = torch.rand(20, 32, 512)

定义你的回调:

def hook(module, x, y):
    print(f'is tuple=isinstance(x, tuple) - length=len(x)')      
    src, tgt = x
  
    print(f'src: src.shape')
    print(f'tgt: tgt.shape')

连接到您的nn.Module

>>> model.register_forward_hook(hook)

做一个推理:

>>> out = model(src, tgt)
is tuple=True - length=2
src: torch.Size([10, 32, 512])
tgt: torch.Size([20, 32, 512])

【讨论】:

爱你,兄弟!我明白了。

以上是关于如何在pytorch中为nn.Transformer编写一个前向钩子函数?的主要内容,如果未能解决你的问题,请参考以下文章

如何在pytorch中为nn.Transformer编写一个前向钩子函数?

在 pytorch 中为 CNN 设置自定义内核

在 pytorch 中为聊天机器人加载经过训练的模型保存

如何在 PyTorch 中对子集使用不同的数据增强

在 Pytorch 中为 HDF5 文件创建数据集和数据加载器时出现问题:没有足够的值来解包(预期 2,得到 1)

深度学习pytorch编写新的层