为啥 torch.nn.MultiheadAttention 中的 W_q 矩阵是二次的
Posted
技术标签:
【中文标题】为啥 torch.nn.MultiheadAttention 中的 W_q 矩阵是二次的【英文标题】:Why W_q matrix in torch.nn.MultiheadAttention is quadratic为什么 torch.nn.MultiheadAttention 中的 W_q 矩阵是二次的 【发布时间】:2020-12-18 19:09:31 【问题描述】:我正在尝试在我的网络中实现 nn.MultiheadAttention。根据docs,
embed_dim - 模型的总尺寸。
但是,根据source file,
embed_dim 必须能被 num_heads 整除
和
self.q_proj_weight = 参数(torch.Tensor(embed_dim, embed_dim))
如果我理解正确,这意味着每个头部只采用每个查询的一部分特征,因为矩阵是二次的。是认识的错误还是我的理解有误?
【问题讨论】:
【参考方案1】:每个头部使用投影查询向量的不同部分。您可以将其想象为查询被拆分为 num_heads
向量,这些向量独立用于计算缩放的点积注意力。因此,每个头都对查询中的特征(以及键和值)的不同线性组合进行操作。这种线性投影是使用self.q_proj_weight
矩阵完成的,投影查询被传递给F.multi_head_attention_forward
函数。
在F.multi_head_attention_forward
中,它是通过对查询向量进行整形和转置来实现的,这样就可以计算出各个头部的独立注意力efficiently by matrix multiplication。
注意力头大小是 PyTorch 的设计决定。理论上,你可以有不同的头部大小,所以投影矩阵的形状为embedding_dim
× num_heads * head_dims
。转换器的一些实现(例如基于 C++ 的 Marian 用于机器翻译,或 Huggingface's Transformers)允许这样做。
【讨论】:
谢谢!尽管如此,问题还是略有不同:在实现手电筒时,我只能看到形状矩阵的乘法 (embed_dim x embed_dim
),因此当我输入 Q 时,每个头部确实会使用 特征子集 i> 的矩阵 Q?
直觉上,你可以这么说;从技术上讲,每个注意力头都对特征的线性组合进行操作是正确的。以上是关于为啥 torch.nn.MultiheadAttention 中的 W_q 矩阵是二次的的主要内容,如果未能解决你的问题,请参考以下文章
为啥 DataGridView 上的 DoubleBuffered 属性默认为 false,为啥它受到保护?