Pytorch CIFAR10图像分类 Vision Transformer(ViT) 篇

Posted 风信子的猫Redamancy

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Pytorch CIFAR10图像分类 Vision Transformer(ViT) 篇相关的知识,希望对你有一定的参考价值。

Pytorch CIFAR10图像分类 Vision Transformer(ViT) 篇

文章目录


这里贴一下汇总篇: 汇总篇

4. 定义网络(ViT篇)

Vision Transformer(ViT)简介

近些年,随着基于自注意(Self-Attention)结构的模型的发展,特别是Transformer模型的提出,极大地促进了自然语言处理模型的发展。由于Transformers的计算效率和可扩展性,它已经能够训练具有超过100B参数的空前规模的模型。

ViT则是自然语言处理和计算机视觉两个领域的融合结晶。在不依赖卷积操作的情况下,依然可以在图像分类任务上达到很好的效果。

模型结构

ViT模型的主体结构是基于Transformer模型的Encoder部分(部分结构顺序有调整,如:Normalization的位置与标准Transformer不同),其结构图[1]如下:

模型特点

ViT模型主要应用于图像分类领域。因此,其模型结构相较于传统的Transformer有以下几个特点:

  1. 数据集的原图像被划分为多个patch后,将二维patch(不考虑channel)转换为一维向量,再加上类别向量与位置向量作为模型输入。
  2. 模型主体的Block结构是基于Transformer的Encoder结构,但是调整了Normalization的位置,其中,最主要的结构依然是Multi-head Attention结构。
  3. 模型在Blocks堆叠后接全连接层,接受类别向量的输出作为输入并用于分类。通常情况下,我们将最后的全连接层称为Head,Transformer Encoder部分为backbone。

Transformer基本原理

Transformer模型源于2017年的一篇文章[2]。在这篇文章中提出的基于Attention机制的编码器-解码器型结构在自然语言处理领域获得了巨大的成功。模型结构如下图所示:

其主要结构为多个Encoder和Decoder模块所组成,其中Encoder和Decoder的详细结构如下图[2]所示:

Encoder与Decoder由许多结构组成,如:多头注意力(Multi-Head Attention)层,Feed Forward层,Normaliztion层,甚至残差连接(Residual Connection,图中的“Add”)。不过,其中最重要的结构是多头注意力(Multi-Head Attention)结构,该结构基于自注意力(Self-Attention)机制,是多个Self-Attention的并行组成。

所以,理解了Self-Attention就抓住了Transformer的核心。

Attention模块

以下是Self-Attention的解释,其核心内容是为输入向量的每个单词学习一个权重。通过给定一个任务相关的查询向量Query向量,计算Query和各个 Key的相似性或者相关性得到注意力分布,即得到每个Key对应Value的权重系数,然后对Value进行加权求和得到最终的Attention数值。

在Self-Attention中:

  1. 最初的输入向量首先会经过Embedding层映射成Q (Query), K K K (Key),V (Value) 三个向量,由于是并行操作,所以代码中是映射成为dim x x x 3 的向量然后进行分割,换言之,如果你的输入向量为一个向量序列 ( x 1 , x 2 , x 3 ) \\left(x_1 , x_2 , x_3\\right) (x1x2x3) ,其中的 x 1 , x 2 , x 3 x_1 , x_2 , x_3 x1x2x3 都是一维向量,那么每一个一维向量 都会经过Embedding层映射出Q,K,V三个向量,只是Embedding矩阵不同,矩阵参数也是通过学习得到的。这里大家可以认为, Q , K , V Q , K , V QKV 三个 矩阵是发现向量之间关联信息的一种手段,需要经过学习得到,至于为什么是 Q , K , V Q , K , V QKV 三个,主要是因为需要两个向量点乘以获得权重,又需要 另一个向量来承载权重向加的结果,所以,最少需要 3 个矩阵。
    q i = W q ⋅ x i k i = W k ⋅ x i , i = 1 , 2 , 3 … v i = W v ⋅ x i \\left\\\\beginarraylq_i=W_q \\cdot x_i \\\\k_i=W_k \\cdot x_i, \\quad i=1,2,3 \\ldots \\\\v_i=W_v \\cdot x_i\\endarray\\right. qi=Wqxiki=Wkxi,i=1,2,3vi=Wvxi

  2. 自注意力机制的自注意主要体现在它的Q,K,V都来源于其自身,也就是该过程是在提取输入的不同顺序的向量的联系与特征,最终通过不同顺序向量之间的联系紧密性(Q与K乘积经过Softmax的结果)来表现出来。Q,K,V得到后就需要获取向量间权重,需要对Q和K进行点乘并除以维度的平方根,对所有向量的结果进行Softmax处理,通过公式(2)的操作,我们获得了向量之间的关系权重。
    a 1 , 1 = q 1 ⋅ k 1 / d a 1 , 2 = q 1 ⋅ k 2 / d a 1 , 3 = q 1 ⋅ k 3 / d \\beginsplit\\begincases a_1,1 = q_1 \\cdot k_1 / \\sqrt d \\\\ a_1,2 = q_1 \\cdot k_2 / \\sqrt d \\\\ a_1,3 = q_1 \\cdot k_3 / \\sqrt d \\endcases \\endsplit a1,1=q1k1/d a1,2=q1k2/d a1,3=q1k3/d

    S o f t m a x : a ^ 1 , i = e x p ( a 1 , i ) / ∑ j e x p ( a 1 , j ) , j = 1 , 2 , 3 … (3) Softmax: \\hat a_1,i = exp(a_1,i) / \\sum_j exp(a_1,j),\\hspace1em j = 1,2,3 \\ldots \\tag3 Softmax:a^1,i=exp(a1,i)/jexp(a1,j),j=1,2,3(3)

  3. 其最终输出则是通过V这个映射后的向量与Q,K经过Softmax结果进行weight sum获得,这个过程可以理解为在全局上进行自注意表示。每一组Q,K,V最后都有一个V输出,这是Self-Attention得到的最终结果,是当前向量在结合了它与其他向量关联权重后得到的结果。
    b 1 = ∑ i a ^ 1 , i v i , i = 1 , 2 , 3... (4) b_1 = \\sum_i \\hat a_1,iv_i,\\hspace1em i = 1,2,3... \\tag4 b1=ia^1,ivi,i=1,2,3...(4)
    通过下图可以整体把握Self-Attention的全部过程。

    多头注意力机制就是将原本self-Attention处理的向量分割为多个Head进行处理,这一点也可以从代码中体现,这也是attention结构可以进行并行加速的一个方面。

    总结来说,多头注意力机制在保持参数总量不变的情况下,将同样的query, key和value映射到原来的高维空间(Q,K,V)的不同子空间 ( Q 0 , K 0 , V 0 ) (Q_0,K_0,V_0) (Q0,K0,V0)中进行自注意力的计算,最后再合并不同子空间中的注意力信息。

    所以,对于同一个输入向量,多个注意力机制可以同时对其进行处理,即利用并行计算加速处理过程,又在处理的时候更充分的分析和利用了向量特征。下图展示了多头注意力机制,其并行能力的主要体现在下图中的a1和a2是同一个向量进行分割获得的。

class Attention(nn.Module):
    '''
    Attention Module used to perform self-attention operation allowing the model to attend
    information from different representation subspaces on an input sequence of embeddings.
    The sequence of operations is as follows :-

    Input -> Query, Key, Value -> ReshapeHeads -> Query.TransposedKey -> Softmax -> Dropout
    -> AttentionScores.Value -> ReshapeHeadsBack -> Output
    '''
    def __init__(self, 
                 embed_dim, # 输入token的dim
                 heads=8, 
                 activation=None, 
                 attn_drop_ratio=0.,
                 proj_drop_ratio=0.):
        super(以上是关于Pytorch CIFAR10图像分类 Vision Transformer(ViT) 篇的主要内容,如果未能解决你的问题,请参考以下文章

Pytorch CIFAR10图像分类 ResNeXt篇

Pytorch CIFAR10图像分类 ResNeXt篇

Pytorch CIFAR10图像分类 ResNet篇

Pytorch CIFAR10图像分类 EfficientNet v1篇

Pytorch CIFAR10图像分类 EfficientNet v1篇

Pytorch CIFAR10图像分类 EfficientNet v1篇