multiheadattention-torch

Posted lixyuan

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了multiheadattention-torch相关的知识,希望对你有一定的参考价值。

multiheadattention

import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class ScaledDotProductAttention(nn.Module):

    def forward(self, query, key, value, mask=None):
        dk = query.size()[-1]
        scores = query.matmul(key.transpose(-2, -1)) / math.sqrt(dk)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        attention = F.softmax(scores, dim=-1)
        return attention.matmul(value)

class MultiSelfAttention(nn.Module):

    def __init__(self, heads, d_model, dropout = 0.1):
        super().__init__()
        
        self.d_model = d_model
        self.d_k = d_model // heads
        self.h = heads
        
        self.q_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        
        self.dropout = nn.Dropout(dropout)
        self.out = nn.Linear(d_model, d_model)
        self.attention = ScaledDotProductAttention()
    
    def forward(self, q, k, v, mask=None):
        
        bs = q.size(0) #batch
        
        # perform linear operation and split into N heads
        k = self.k_linear(k).view(bs, -1, self.h, self.d_k)
        q = self.q_linear(q).view(bs, -1, self.h, self.d_k)
        v = self.v_linear(v).view(bs, -1, self.h, self.d_k)
        
        # transpose to get dimensions bs * N * sl * d_model
        k = k.transpose(1,2)
        q = q.transpose(1,2)
        v = v.transpose(1,2)
        
        # calculate attention using function we will define next
        scores = self.attention(q,k,v)
        # concatenate heads and put through final linear layer
        concat = scores.transpose(1,2).contiguous()        .view(bs, -1, self.d_model)
        output = self.out(concat)
    
        return output

以上是关于multiheadattention-torch的主要内容,如果未能解决你的问题,请参考以下文章

R留学生作业代码代写代编程代编程代编程

IPEX-1代/3代/4代/5代,PCB天线底座,公头,样式及封装尺寸图

JVM 年轻代 老年代 持久代 gc

深圳本地网店代运营公司

C线程代业代写代调试POSIX Threads代编码

jvm中的年轻代 老年代 持久代 gc