如何在pytorch中为nn.Transformer编写一个前向钩子函数?
Posted
技术标签:
【中文标题】如何在pytorch中为nn.Transformer编写一个前向钩子函数?【英文标题】:How to write a forward hook function for nn.Transformer in pytorch? 【发布时间】:2021-11-04 03:12:42 【问题描述】:我了解到前向钩子函数的形式为hook_fn(m,x,y)
。 m 指模型,x 指输入,y 指输出。我想为nn.Transformer
写一个前向钩子函数。
但是,变压器层需要输入 src 和 tgt。例如,>>> out = transformer_model(src, tgt)
。那么如何区分这些输入呢?
【问题讨论】:
请提供足够的代码,以便其他人更好地理解或重现问题。 【参考方案1】:你的钩子会用 tuples 为x
和y
调用你的回调函数。正如torch.nn.Module.register_forward_hook
的文档页面中所述(它确实很好地解释了x
和y
的类型)。
输入只包含给模块的位置参数。 关键字参数不会传递给钩子,而只会传递给 向前。 [...]。
model = nn.Transformer(nhead=16, num_encoder_layers=12)
src = torch.rand(10, 32, 512)
tgt = torch.rand(20, 32, 512)
定义你的回调:
def hook(module, x, y):
print(f'is tuple=isinstance(x, tuple) - length=len(x)')
src, tgt = x
print(f'src: src.shape')
print(f'tgt: tgt.shape')
连接到您的nn.Module
:
>>> model.register_forward_hook(hook)
做一个推理:
>>> out = model(src, tgt)
is tuple=True - length=2
src: torch.Size([10, 32, 512])
tgt: torch.Size([20, 32, 512])
【讨论】:
爱你,兄弟!我明白了。以上是关于如何在pytorch中为nn.Transformer编写一个前向钩子函数?的主要内容,如果未能解决你的问题,请参考以下文章
如何在pytorch中为nn.Transformer编写一个前向钩子函数?