pytorch笔记:nn.MultiheadAttention
Posted UQI-LIUWJ
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了pytorch笔记:nn.MultiheadAttention相关的知识,希望对你有一定的参考价值。
1 函数介绍
torch.nn.MultiheadAttention(
embed_dim,
num_heads,
dropout=0.0,
bias=True,
add_bias_kv=False,
add_zero_attn=False,
kdim=None,
vdim=None,
batch_first=False,
device=None,
dtype=None)
2 参数介绍
embed_dim | 模型的维度 |
num_heads | attention的头数 (embed_dim会平均分配给每个头,也即每个头的维度是embed_dim//num_heads) |
dropout | attn_output_weights的dropout概率 |
bias | input和output的投影函数,是否有bias |
kdim | k的维度,默认embed_dim |
vdim | v的维度,默认embed_dim |
batch_first | True——输入和输出的维度是(batch_num,seq_len,feature_dim) |
False——输入和输出的维度是(batch_num,seq_len,feature_dim) |
3 forward函数
forward(
query,
key,
value,
key_padding_mask=None,
need_weights=True,
attn_mask=None,
average_attn_weights=True)
4 forward函数参数介绍
query |
|
key |
|
value |
|
key_padding_mask | 如果设置,那么
True表示对应的key value在计算attention的时候,需要被忽略 |
need_weights | 如果设置,那么返回值会多一个attn_output_weight |
attn_mask | True表示对应的attention value 不应该存在 |
average_attn_weights | 如果设置,那么返回的是各个头的平均attention weight 否则,就是把所有的head分别输出 |
5 forward输出
attn_output |
|
attn_output_weight |
|
6 举例
import torch
import torch.nn as nn
lst=torch.Tensor([[1,2,3,4],
[2,3,4,5],
[7,8,9,10]])
lst=lst.unsqueeze(1)
lst.shape
#torch.Size([3, 1, 4])
multi_atten=nn.MultiheadAttention(embed_dim=4,
num_heads=2)
multi_atten(lst,lst,lst)
'''
(tensor([[[ 1.9639, -3.7282, 2.1215, 0.6630]],
[[ 2.2423, -4.2444, 2.2466, 1.0711]],
[[ 2.3823, -4.5058, 2.3015, 1.2964]]], grad_fn=<AddBackward0>),
tensor([[[9.0335e-02, 1.2198e-01, 7.8769e-01],
[2.6198e-02, 4.4854e-02, 9.2895e-01],
[1.6031e-05, 9.4658e-05, 9.9989e-01]]], grad_fn=<DivBackward0>))
'''
以上是关于pytorch笔记:nn.MultiheadAttention的主要内容,如果未能解决你的问题,请参考以下文章