multiheadattention-torch
Posted by lixyuan
tags:
篇首语：本文由小常识网(cha138.com)小编为大家整理，主要介绍了multiheadattention-torch相关的知识，希望对你有一定的参考价值。
multiheadattention
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention: ``softmax(Q K^T / sqrt(d_k)) V``.

    The module has no parameters; it only combines already-projected query,
    key and value tensors along their last two dimensions.
    """

    def forward(self, query, key, value, mask=None, dropout=None):
        """Attend from ``query`` to ``key``/``value``.

        Args:
            query: tensor of shape ``(..., L_q, d_k)``.
            key: tensor of shape ``(..., L_k, d_k)``.
            value: tensor of shape ``(..., L_k, d_v)``.
            mask: optional tensor broadcastable to ``(..., L_q, L_k)``;
                positions where ``mask == 0`` receive zero attention weight.
            dropout: optional callable (e.g. an ``nn.Dropout``) applied to the
                attention weights after the softmax. ``None`` (the default)
                applies no dropout.

        Returns:
            Tensor of shape ``(..., L_q, d_v)``.
        """
        d_k = query.size(-1)
        scores = query.matmul(key.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            # Use the most negative finite value of the scores' dtype instead
            # of a fixed -1e9: -1e9 does not fit in float16, and masked_fill
            # raises on overflow. After softmax both give (effectively) zero.
            fill_value = torch.finfo(scores.dtype).min
            scores = scores.masked_fill(mask == 0, fill_value)
        attention = F.softmax(scores, dim=-1)
        if dropout is not None:
            attention = dropout(attention)
        return attention.matmul(value)


class MultiSelfAttention(nn.Module):
    """Multi-head attention.

    Projects queries, keys and values with separate linear layers, splits
    each projection into ``heads`` slices of width ``d_model // heads``, runs
    scaled dot-product attention on all heads in parallel, then concatenates
    the heads and applies an output projection.

    Args:
        heads: number of attention heads; must be positive and divide
            ``d_model``.
        d_model: width of the inputs and of the output.
        dropout: dropout probability applied to the attention weights.

    Raises:
        ValueError: if ``heads`` is not positive or does not divide
            ``d_model``.
    """

    def __init__(self, heads, d_model, dropout=0.1):
        super().__init__()
        # Without this check, view(bs, -1, h, d_k) can silently succeed with
        # the wrong shape and mix features from different tokens.
        if heads <= 0 or d_model % heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by heads ({heads})"
            )
        self.d_model = d_model
        self.d_k = d_model // heads
        self.h = heads
        self.q_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        self.out = nn.Linear(d_model, d_model)
        self.attention = ScaledDotProductAttention()

    def forward(self, q, k, v, mask=None):
        """Compute multi-head attention.

        Args:
            q: queries of shape ``(bs, L_q, d_model)``.
            k: keys of shape ``(bs, L_k, d_model)``.
            v: values of shape ``(bs, L_k, d_model)``.
            mask: optional mask where 0 marks positions to ignore. A 3-D mask
                (e.g. ``(bs, L_q, L_k)`` or ``(bs, 1, L_k)``) gets a head axis
                inserted so it applies to every head; 2-D and 4-D masks are
                passed through and must broadcast to ``(bs, heads, L_q, L_k)``.

        Returns:
            Tensor of shape ``(bs, L_q, d_model)``.
        """
        bs = q.size(0)
        # Project, then split into heads:
        # (bs, seq_len, d_model) -> (bs, seq_len, h, d_k) -> (bs, h, seq_len, d_k)
        k = self.k_linear(k).view(bs, -1, self.h, self.d_k).transpose(1, 2)
        q = self.q_linear(q).view(bs, -1, self.h, self.d_k).transpose(1, 2)
        v = self.v_linear(v).view(bs, -1, self.h, self.d_k).transpose(1, 2)
        if mask is not None and mask.dim() == 3:
            # Insert a head axis so the per-batch mask broadcasts over heads
            # instead of aligning its batch axis with the head axis.
            mask = mask.unsqueeze(1)
        heads_out = self.attention(q, k, v, mask, dropout=self.dropout)
        # (bs, h, L_q, d_k) -> (bs, L_q, h, d_k) -> (bs, L_q, d_model)
        concat = heads_out.transpose(1, 2).contiguous().view(bs, -1, self.d_model)
        return self.out(concat)
以上是关于multiheadattention-torch的主要内容，如果未能解决你的问题，请参考以下文章