基于torch函数TransformerEncoder出现AssertionError问题的解决

Posted 小乖乖的臭坏坏

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了基于torch函数TransformerEncoder出现AssertionError问题的解决相关的知识,希望对你有一定的参考价值。

在使用transformer model时,由于存在encoder-decoder,encoder-only,decoder-only三种结构以应对不同的task。当我们使用encoder-only时,必然会涉及到TransformerEncoder和TransformerEncoderLayer函数的调用。
那么如下代码出现了AssertionError问题,应当如何解决?
在这里插入图片描述
为什么会出现AssertionError(声明/断言)问题呢?可以看到,输入模型的第三维应该对应d_model这个参数,那么此处,这两个值应该一致。
在这里插入图片描述
修改以后:
在这里插入图片描述
运行得到:
在这里插入图片描述
其实我们发现,就transformer的编码器而言,输入输出的尺寸是一样的。

作于:
20215-9
21:50

以上是关于基于torch函数TransformerEncoder出现AssertionError问题的解决的主要内容,如果未能解决你的问题,请参考以下文章

推荐模型复现:熟悉Torch-RecHub框架与使用

Pytorch中torch.unsqueeze()和torch.squeeze()函数解析

torch.stack() 和 torch.cat() 函数有啥区别?

如何用其他 pytorch 函数替换 torch.sparse?

从图像角度理解torch.mean()函数。继而学习torch.max等等相关函数

如何用其他 pytorch 函数替换 torch.norm?