MultiHeadAttnetion 中的 att_mask 和 key_padding_mask 有啥区别

Posted

技术标签:

【中文标题】MultiHeadAttnetion 中的 att_mask 和 key_padding_mask 有啥区别【英文标题】:what the difference between att_mask and key_padding_mask in MultiHeadAttnetionMultiHeadAttnetion 中的 att_mask 和 key_padding_mask 有什么区别 【发布时间】:2020-10-19 02:46:32 【问题描述】:

pytorch的MultiHeadAttnetionatt_maskkey_padding_mask有什么区别:

key_padding_mask – 如果提供,key 中指定的填充元素将被注意力忽略。当给定一个二进制掩码并且值为 True 时,注意力层上的相应值将被忽略。当给定一个字节掩码并且一个值不为零时,注意力层上的相应值将被忽略

attn_mask – 2D 或 3D 掩码,可防止对某些位置的注意力。将为所有批次广播 2D 掩码,而 3D 掩码允许为每个批次的条目指定不同的掩码。

提前致谢。

【问题讨论】:

【参考方案1】:

key_padding_mask 用于屏蔽正在填充的位置,即在输入序列结束之后。这始终特定于输入批次,并且取决于批次中的序列与最长的序列相比有多长。它是一个形状为 batch size × input length 的 2D 张量。

另一方面,attn_mask 表示哪些键值对是有效的。在 Transformer 解码器中,三角形掩码用于模拟推理时间并防止关注“未来”位置。这就是att_mask 通常的用途。如果是二维张量,形状为输入长度 × 输入长度。您还可以有一个特定于批次中每个项目的掩码。在这种情况下,您可以使用形状为 (batch size × num Heads) × input length × input length 的 3D 张量。 (因此,理论上,您可以用 3D att_mask 模拟 key_padding_mask。)

【讨论】:

拥有一个特定于批次中每个项目的掩码的目的是什么?好奇。 批量中每个项目的差异位置可能存在填充。例如如果输入是一系列句子,并且它们在开头或结尾被填充,我们需要为每个句子应用一个单独的掩码。在解码器的情况下,此掩码将是 attn_mask 和 key_padding_mask 的组合(指键、值的编码器输入) 在为批次中的每个项目传递掩码时,模块是否对每个注意力头使用沿 0 维度的顺序项目?即当batch_size=32num_heads=4 时,att_mask[:4,:,:] 是项目 1 的掩码(用于头部 1、2、3 和 4)?【参考方案2】:

我认为它们的工作方式相同:两个掩码都定义了查询和键之间的注意不会被使用。而这两种选择的唯一区别就是在哪个形状上输入面具更舒服

根据代码,这两个掩码似乎被合并/合并,所以它们都扮演相同的角色——不会使用查询和键之间的注意力。因为它们是联合的:如果您需要使用两个掩码,则两个掩码输入可以是不同的值,或者您可以根据需要的形状方便地在任何 mask_args 中输入掩码:这是原始代码的一部分pytorch/functional.py 在函数 multi_head_attention_forward() 中的第 5227 行附近

...
# merge key padding and attention masks
if key_padding_mask is not None:
    assert key_padding_mask.shape == (bsz, src_len), \
        f"expecting key_padding_mask shape of (bsz, src_len), but got key_padding_mask.shape"
    key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len).   \
        expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
    if attn_mask is None:
        attn_mask = key_padding_mask
    elif attn_mask.dtype == torch.bool:
        attn_mask = attn_mask.logical_or(key_padding_mask)
    else:
        attn_mask = attn_mask.masked_fill(key_padding_mask, float("-inf"))
...
# so here only the merged/unioned mask is used to actually compute the attention
attn_output, attn_output_weights = _scaled_dot_product_attention(q, k, v, attn_mask, dropout_p)

如果你有不同的意见或我错了,请纠正我。

【讨论】:

以上是关于MultiHeadAttnetion 中的 att_mask 和 key_padding_mask 有啥区别的主要内容,如果未能解决你的问题,请参考以下文章

jquery 取id模糊查询

初识att&ack框架

Amazon Route 53 DNS 反向查找区域 - ATT IP 块设置

Keras 中的多个输出

在这种情况下,jmp 指令如何在 att 汇编中工作

左手攻击(ATT&CK),右手防御(Shield)|230种招式教你花式应对黑客攻击