代码实现 加性注意力 | additive attention #51CTO博主之星评选#

Posted LolitaAnn

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了代码实现 加性注意力 | additive attention #51CTO博主之星评选#相关的知识,希望对你有一定的参考价值。

import math
import torch
from torch import nn
from d2l import torch as d2l

python人必懂的导包,这不用解释了。

def masked_softmax(X, valid_lens):
    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    else:
        shape = X.shape
        if valid_lens.dim() == 1:
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
        else:
            valid_lens = valid_lens.reshape(-1)
        X = d2l.sequence_mask(X.reshape(-1, shape[-1]), valid_lens,
                              value=-1e6)
        return nn.functional.softmax(X.reshape(shape), dim=-1)

一个遮蔽softmax的操作。在nadaraya-waston核回归代码实现中我们做过一个类似的mask操作。就是倒数第三段代码那个位置,每个$x$和除自己本身以外的其他$x_i$进行计算,然后我们使用X_tile[(1 - torch.eye(n_train)).type(torch.bool)]将其本身遮盖掉了。也就是mask操作。

这个函数的功能是这样的:就是我们传入的一整个张量可能只有一部分是有用的,所以将没用的部分mask掉,只对剩下的部分进行softmax计算。比如我们传入一个长度为5的向量,我们仅需要前两个数据,那经过这个函数之后,后三个数加起来是0,前两个数加起来是1。

  • 函数两个参数Xvalid_lens,x是要softmax的张量,valid_lens存储每个维度上的有效长度,不管传入一维还是二维,都要确保能进行广播机制。
  • 函数一进来是一个if语句if valid_lens is None是说如果没有给出valid_lens,也就是整个张量都是有效的,不需要进行mask之后再softmax,所以if语句直接返回一个普通的softmax操作,函数运行结束。
  • 当传入valid_lens的时候进入else

    • 首先是用shape存储待mask的张量X的shape。
    • 又是一个if-else语句,这个是用来处理valid_lens长度的,将valid_lens长度转化矩阵的行数。

      • valid_lens是一维的时候进入if,将其转换为一个mask向量。解释一下,因为mini-batch的存在,所以传入的X一般是三维的,第一个维度是batch size,二三维度上的才是矩阵的大小。之前用shape存储X的shape,现在用shape[1]取到X中的矩阵是几行,然后每行的有效元素对应valid_lens中的数值。

        想了解torch.repeat_interleave看这里→pytorch中的repeat操作对比

      • valid_lens不是一维的时候进入else中。直接将其从一个矩阵转化为一个向量即可。
      • 对于mask操作是直接用d2l中的函数实现的,源码我就不去扒了,对于维度的处理记住:
        • 如果传入的valid_lens是一维的,那valid_lens的长度要和X的第二维(shape[1])一样。
        • 如果传入的valid_lens是二维的,那valid_lens的第一维度要和batch size一样,第二维度要和X中矩阵的行数一样。
        • 具体例子可以看代码实现 缩放点积注意力 | scaled dot-product attention
class AdditiveAttention(nn.Module):
    def __init__(self, key_size, query_size, num_hiddens, dropout, **kwargs):
        super(AdditiveAttention, self).__init__(**kwargs)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=False)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=False)
        self.w_v = nn.Linear(num_hiddens, 1, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens):
        queries, keys = self.W_q(queries), self.W_k(keys)
        features = queries.unsqueeze(2) + keys.unsqueeze(1)
        features = torch.tanh(features)
        scores = self.w_v(features).squeeze(-1)
        self.attention_weights = masked_softmax(scores, valid_lens)
        return torch.bmm(self.dropout(self.attention_weights), values)

加性注意力代码部分:

因为这里涉及到一个升到四维张量,所以一定要自己捋一捋。

  • 主要的三个参数,key_sizekeys的长度, query_sizequery的长度, num_hiddens隐藏层的大小。因为加性注意力是处理keys和queries长度不一样的情况。
  • 三个小的线性层。self.W_kself.W_q是把key和query转化到隐藏层,self.W_v是从隐藏层到单个输出。
  • 在这里均设置不需要bias
  • 最后还做了一下dropout
  • 然后是前向传播函数,是计算$a(\\mathbf q, \\mathbf k) = \\mathbf w_v^\\top \\texttanh(\\mathbf W_q\\mathbf q + \\mathbf W_k \\mathbf k)$的过程:

    • 将queries和keys扔进前边两个线性层就可以得到queries和keys,进行维度调整。

      queries 的形状:(batch_size, 查询的个数, 1, num_hidden)

      key 的形状:(batch_size, 1, “键-值”对的个数, num_hiddens)

    • 进行公式的计算。
    • scores的计算是self.w_v 仅有一个输出,因此从形状中移除最后那个维度。
      scores 的形状:(batch_size, 查询的个数, “键-值”对的个数)
    • 最后values 的形状:(batch_size, “键-值”对的个数, 值的维度)
queries, keys = torch.normal(0, 1, (2, 1, 20)), torch.ones((2, 10, 2))
# `values` 的小批量数据集中,两个值矩阵是相同的
values = torch.arange(40, dtype=torch.float32).reshape(1, 10, 4).repeat(
    2, 1, 1)
valid_lens = torch.tensor([2, 6])

attention = AdditiveAttention(key_size=2, query_size=20, num_hiddens=8, dropout=0.1)
attention.eval()
attention(queries, keys, values, valid_lens)

带入一个样例测试一下子。

注意这里使用到.eval(),是不启用 BatchNormalization 和 Dropout。

d2l.show_heatmaps(attention.attention_weights.reshape((1, 1, 2, 10)),
                  xlabel=Keys, ylabel=Queries)

因为和代码实现 缩放点积注意力 | scaled dot-product attention用的数据都一样的,所以就不具体解析这个热图了,不懂的可以看点积缩放注意力那篇文章的热图分析。

以上是关于代码实现 加性注意力 | additive attention #51CTO博主之星评选#的主要内容,如果未能解决你的问题,请参考以下文章

R语言mgcv包中的gam函数拟合广义加性模型:线性回归与广义加性模型GAMs(Generalized Additive Model)模型性能比较(比较RMSE比较R方指标)

R语言使用mgcv包中的gam函数拟合广义加性模型(Generalized Additive Model,GAMs):从广义加性模型GAM中抽取学习到的样条函数(spline function)

R语言广义加性模型(GAMs:Generalized Additive Model)建模:数据加载划分数据并分别构建线性回归模型和广义线性加性模型GAMs并比较线性模型和GAMs模型的性能

R语言mgcv包中的gam函数拟合广义加性模型(Generalized Additive Model)GAM(对非线性变量进行样条处理计算RMSER方调整R方可视化模型预测值与真实值的曲线)

log4j可加性,类别日志记录级别和追加者阈值

如何在记录器 xml 标记内的 appender refs 中禁用 log4j2.xml 可加性