代码实现 加性注意力 | additive attention #51CTO博主之星评选#
Posted LolitaAnn
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了代码实现 加性注意力 | additive attention #51CTO博主之星评选#相关的知识,希望对你有一定的参考价值。
import math
import torch
from torch import nn
from d2l import torch as d2l
python人必懂的导包,这不用解释了。
def masked_softmax(X, valid_lens):
if valid_lens is None:
return nn.functional.softmax(X, dim=-1)
else:
shape = X.shape
if valid_lens.dim() == 1:
valid_lens = torch.repeat_interleave(valid_lens, shape[1])
else:
valid_lens = valid_lens.reshape(-1)
X = d2l.sequence_mask(X.reshape(-1, shape[-1]), valid_lens,
value=-1e6)
return nn.functional.softmax(X.reshape(shape), dim=-1)
一个遮蔽softmax的操作。在nadaraya-waston核回归代码实现中我们做过一个类似的mask操作。就是倒数第三段代码那个位置,每个$x$和除自己本身以外的其他$x_i$进行计算,然后我们使用X_tile[(1 - torch.eye(n_train)).type(torch.bool)]
将其本身遮盖掉了。也就是mask操作。
这个函数的功能是这样的:就是我们传入的一整个张量可能只有一部分是有用的,所以将没用的部分mask掉,只对剩下的部分进行softmax计算。比如我们传入一个长度为5的向量,我们仅需要前两个数据,那经过这个函数之后,后三个数加起来是0,前两个数加起来是1。
- 函数两个参数
X
和valid_lens
,x是要softmax的张量,valid_lens存储每个维度上的有效长度,不管传入一维还是二维,都要确保能进行广播机制。 - 函数一进来是一个if语句
if valid_lens is None
是说如果没有给出valid_lens
,也就是整个张量都是有效的,不需要进行mask之后再softmax,所以if语句直接返回一个普通的softmax操作,函数运行结束。 -
当传入
valid_lens
的时候进入else- 首先是用
shape
存储待mask的张量X
的shape。 -
又是一个if-else语句,这个是用来处理
valid_lens
长度的,将valid_lens
长度转化矩阵的行数。-
当
valid_lens
是一维的时候进入if,将其转换为一个mask向量。解释一下,因为mini-batch的存在,所以传入的X
一般是三维的,第一个维度是batch size,二三维度上的才是矩阵的大小。之前用shape
存储X
的shape,现在用shape[1]
取到X
中的矩阵是几行,然后每行的有效元素对应valid_lens
中的数值。想了解
torch.repeat_interleave
看这里→pytorch中的repeat操作对比 - 当
valid_lens
不是一维的时候进入else中。直接将其从一个矩阵转化为一个向量即可。 - 对于mask操作是直接用d2l中的函数实现的,源码我就不去扒了,对于维度的处理记住:
- 如果传入的
valid_lens
是一维的,那valid_lens
的长度要和X
的第二维(shape[1]
)一样。 - 如果传入的
valid_lens
是二维的,那valid_lens
的第一维度要和batch size一样,第二维度要和X
中矩阵的行数一样。 - 具体例子可以看代码实现 缩放点积注意力 | scaled dot-product attention
- 如果传入的
-
- 首先是用
class AdditiveAttention(nn.Module):
def __init__(self, key_size, query_size, num_hiddens, dropout, **kwargs):
super(AdditiveAttention, self).__init__(**kwargs)
self.W_k = nn.Linear(key_size, num_hiddens, bias=False)
self.W_q = nn.Linear(query_size, num_hiddens, bias=False)
self.w_v = nn.Linear(num_hiddens, 1, bias=False)
self.dropout = nn.Dropout(dropout)
def forward(self, queries, keys, values, valid_lens):
queries, keys = self.W_q(queries), self.W_k(keys)
features = queries.unsqueeze(2) + keys.unsqueeze(1)
features = torch.tanh(features)
scores = self.w_v(features).squeeze(-1)
self.attention_weights = masked_softmax(scores, valid_lens)
return torch.bmm(self.dropout(self.attention_weights), values)
加性注意力代码部分:
因为这里涉及到一个升到四维张量,所以一定要自己捋一捋。
- 主要的三个参数,
key_size
keys的长度,query_size
query的长度,num_hiddens
隐藏层的大小。因为加性注意力是处理keys和queries长度不一样的情况。 - 三个小的线性层。
self.W_k
和self.W_q
是把key和query转化到隐藏层,self.W_v
是从隐藏层到单个输出。 - 在这里均设置不需要bias
- 最后还做了一下dropout
-
然后是前向传播函数,是计算$a(\\mathbf q, \\mathbf k) = \\mathbf w_v^\\top \\texttanh(\\mathbf W_q\\mathbf q + \\mathbf W_k \\mathbf k)$的过程:
-
将queries和keys扔进前边两个线性层就可以得到queries和keys,进行维度调整。
queries
的形状:(batch_size
, 查询的个数, 1,num_hidden
)key
的形状:(batch_size
, 1, “键-值”对的个数,num_hiddens
) - 进行公式的计算。
scores
的计算是self.w_v
仅有一个输出,因此从形状中移除最后那个维度。scores
的形状:(batch_size
, 查询的个数, “键-值”对的个数)- 最后
values
的形状:(batch_size
, “键-值”对的个数, 值的维度)
-
queries, keys = torch.normal(0, 1, (2, 1, 20)), torch.ones((2, 10, 2))
# `values` 的小批量数据集中,两个值矩阵是相同的
values = torch.arange(40, dtype=torch.float32).reshape(1, 10, 4).repeat(
2, 1, 1)
valid_lens = torch.tensor([2, 6])
attention = AdditiveAttention(key_size=2, query_size=20, num_hiddens=8, dropout=0.1)
attention.eval()
attention(queries, keys, values, valid_lens)
带入一个样例测试一下子。
注意这里使用到.eval()
,是不启用 BatchNormalization 和 Dropout。
d2l.show_heatmaps(attention.attention_weights.reshape((1, 1, 2, 10)),
xlabel=Keys, ylabel=Queries)
因为和代码实现 缩放点积注意力 | scaled dot-product attention用的数据都一样的,所以就不具体解析这个热图了,不懂的可以看点积缩放注意力那篇文章的热图分析。
以上是关于代码实现 加性注意力 | additive attention #51CTO博主之星评选#的主要内容,如果未能解决你的问题,请参考以下文章
R语言mgcv包中的gam函数拟合广义加性模型:线性回归与广义加性模型GAMs(Generalized Additive Model)模型性能比较(比较RMSE比较R方指标)
R语言使用mgcv包中的gam函数拟合广义加性模型(Generalized Additive Model,GAMs):从广义加性模型GAM中抽取学习到的样条函数(spline function)
R语言广义加性模型(GAMs:Generalized Additive Model)建模:数据加载划分数据并分别构建线性回归模型和广义线性加性模型GAMs并比较线性模型和GAMs模型的性能
R语言mgcv包中的gam函数拟合广义加性模型(Generalized Additive Model)GAM(对非线性变量进行样条处理计算RMSER方调整R方可视化模型预测值与真实值的曲线)