Swin-Transformer网络结构详解

Posted 太阳花的小绿豆

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Swin-Transformer网络结构详解相关的知识,希望对你有一定的参考价值。

文章目录


0 前言

Swin Transformer是2021年微软研究院发表在ICCV上的一篇文章,并且已经获得ICCV 2021 best paper的荣誉称号。Swin Transformer网络是Transformer模型在视觉领域的又一次碰撞。该论文一经发表就已在多项视觉任务中霸榜。该论文是在2021年3月发表的,现在是2021年11月了,根据官方提供的信息可以看到,现在还在COCO数据集的目标检测以及实例分割任务中是第一名(见下图State of the Art表示第一)。

论文名称:Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
原论文地址: https://arxiv.org/abs/2103.14030
官方开源代码地址:https://github.com/microsoft/Swin-Transformer

Pytorch实现代码: pytorch_classification/swin_transformer
Tensorflow2实现代码:tensorflow_classification/swin_transformer

不想看文章的可以看下我在bilibili上讲的视频: https://www.bilibili.com/video/BV1pL4y1v7jC


1 网络整体框架

在正文开始之前,先来简单对比下Swin Transformer和之前的Vision Transformer(如果不了解Vision Transformer的建议先去看下我之前的文章)。下图是Swin Transformer文章中给出的图1,左边是本文要讲的Swin Transformer,右边边是之前讲的Vision Transformer。通过对比至少可以看出两点不同:

  • Swin Transformer使用了类似卷积神经网络中的层次化构建方法(Hierarchical feature maps),比如特征图尺寸中有对图像下采样4倍的,8倍的以及16倍的,这样的backbone有助于在此基础上构建目标检测,实例分割等任务。而在之前的Vision Transformer中是一开始就直接下采样16倍,后面的特征图也是维持这个下采样率不变。
  • 在Swin Transformer中使用了Windows Multi-Head Self-Attention(W-MSA)的概念,比如在下图的4倍下采样和8倍下采样中,将特征图划分成了多个不相交的区域(Window),并且Multi-Head Self-Attention只在每个窗口(Window)内进行。相对于Vision Transformer中直接对整个(Global)特征图进行Multi-Head Self-Attention,这样做的目的是能够减少计算量的,尤其是在浅层特征图很大的时候。这样做虽然减少了计算量但也会隔绝不同窗口之间的信息传递,所以在论文中作者又提出了 Shifted Windows Multi-Head Self-Attention(SW-MSA)的概念,通过此方法能够让信息在相邻的窗口中进行传递,后面会细讲。

接下来,简单看下原论文中给出的关于Swin Transformer(Swin-T)网络的架构图。通过图(a)可以看出整个框架的基本流程如下:

  • 首先将图片输入到Patch Partition模块中进行分块,即每4x4相邻的像素为一个Patch,然后在channel方向展平(flatten)。假设输入的是RGB三通道图片,那么每个patch就有4x4=16个像素,然后每个像素有R、G、B三个值所以展平后是16x3=48,所以通过Patch Partition后图像shape由 [H, W, 3]变成了 [H/4, W/4, 48]。然后在通过Linear Embeding层对每个像素的channel数据做线性变换,由48变成C,即图像shape再由 [H/4, W/4, 48]变成了 [H/4, W/4, C]。其实在源码中Patch Partition和Linear Embeding就是直接通过一个卷积层实现的,和之前Vision Transformer中讲的 Embedding层结构一模一样。

  • 然后就是通过四个Stage构建不同大小的特征图,除了Stage1中先通过一个Linear Embeding层外,剩下三个stage都是先通过一个Patch Merging层进行下采样(后面会细讲)。然后都是重复堆叠Swin Transformer Block注意这里的Block其实有两种结构,如图(b)中所示,这两种结构的不同之处仅在于一个使用了W-MSA结构,一个使用了SW-MSA结构。而且这两个结构是成对使用的,先使用一个W-MSA结构再使用一个SW-MSA结构。所以你会发现堆叠Swin Transformer Block的次数都是偶数(因为成对使用)。

  • 最后对于分类网络,后面还会接上一个Layer Norm层、全局池化层以及全连接层得到最终输出。图中没有画,但源码中是这样做的。

接下来,在分别对Patch MergingW-MSASW-MSA以及使用到的相对位置偏执(relative position bias)进行详解。关于Swin Transformer Block中的MLP结构和Vision Transformer中的结构是一样的,所以这里也不在赘述,参考


2 Patch Merging详解

前面有说,在每个Stage中首先要通过一个Patch Merging层进行下采样(Stage1除外)。如下图所示,假设输入Patch Merging的是一个4x4大小的单通道特征图(feature map),Patch Merging会将每个2x2的相邻像素划分为一个patch,然后将每个patch中相同位置(同一颜色)像素给拼在一起就得到了4个feature map。接着将这四个feature map在深度方向进行concat拼接,然后在通过一个LayerNorm层。最后通过一个全连接层在feature map的深度方向做线性变化,将feature map的深度由C变成C/2。通过这个简单的例子可以看出,通过Patch Merging层后,feature map的高和宽会减半,深度会翻倍。


3 W-MSA详解

引入Windows Multi-head Self-Attention(W-MSA)模块是为了减少计算量。如下图所示,左侧使用的是普通的Multi-head Self-Attention(MSA)模块,对于feature map中的每个像素(或称作token,patch)在Self-Attention计算过程中需要和所有的像素去计算。但在图右侧,在使用Windows Multi-head Self-Attention(W-MSA)模块时,首先将feature map按照MxM(例子中的M=2)大小划分成一个个Windows,然后单独对每个Windows内部进行Self-Attention。


两者的计算量具体差多少呢?原论文中有给出下面两个公式,这里忽略了Softmax的计算复杂度。:
Ω ( M S A ) = 4 h w C 2 + 2 ( h w ) 2 C                 ( 1 ) Ω ( W − M S A ) = 4 h w C 2 + 2 M 2 h w C       ( 2 ) \\Omega (MSA)=4hwC^2 + 2(hw)^2C \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ (1) \\\\ \\Omega (W-MSA)=4hwC^2 + 2M^2hwC\\ \\ \\ \\ \\ (2) Ω(MSA)=4hwC2+2(hw)2C               (1)Ω(WMSA)=4hwC2+2M2hwC     (2)

  • h代表feature map的高度
  • w代表feature map的宽度
  • C代表feature map的深度
  • M代表每个窗口(Windows)的大小

这个公式是咋来的,原论文中并没有细讲,这里简单说下。首先回忆下单头Self-Attention的公式,如果对Self-Attention不了解的,请看下我之前写的文章
A t t e n t i o n ( Q , K , V ) = S o f t M a x ( Q K T d ) V Attention(Q, K, V)=\\rm SoftMax(\\fracQK^T\\sqrt d)V Attention(Q,K,V)=SoftMax(d QKT)V

MSA模块计算量

对于feature map中的每个像素(或称作token,patch),都要通过 W q , W k , W v W_q, W_k, W_v Wq,Wk,Wv生成对应的query(q),key(k)以及value(v)。这里假设q, k, v的向量长度与feature map的深度C保持一致。那么对应所有像素生成Q的过程如下式:
A h w × C ⋅ W q C × C = Q h w × C A^hw \\times C \\cdot W^C \\times C_q=Q^hw \\times C Ahw×CWqC×C=Qhw×C

  • A h w × C A^hw \\times C Ahw×C为将所有像素(token)拼接在一起得到的矩阵(一共有hw个像素,每个像素的深度为C)
  • W q C × C W^C \\times C_q WqC×C为生成query的变换矩阵
  • Q h w × C Q^hw \\times C Qhw×C为所有像素通过 W q C × C W^C \\times C_q WqC×C得到的query拼接后的矩阵

根据矩阵运算的计算量公式可以得到生成Q的计算量为 h w × C × C hw \\times C \\times C hw×C×C,生成K和V同理都是 h w C 2 hwC^2 hwC2,那么总共是 3 h w C 2 3hwC^2 3hwC2。接下来 Q Q Q K T K^T KT相乘,对应计算量为 ( h w ) 2 C (hw)^2C (hw)2C
Q h w × C ⋅ K T ( C × h w ) = X h w × h w Q^hw \\times C \\cdot K^T(C \\times hw)= X^hw \\times hw Qhw×CKT(C×hw)=Xhw×hw
接下来忽略除以 d \\sqrt d d 以及softmax的计算量,假设得到 Λ h w × h w \\Lambda ^hw \\times hw Λhw×hw,最后还要乘以V,对应的计算量为 ( h w ) 2 C (hw)^2C (hw)2C:
Λ h w × h w ⋅ V h w × C = B h w × C \\Lambda ^hw \\times hw \\cdot V^hw \\times C=B^hw \\times C Λhw×hwVhw×C=Bhw×C
那么对应单头的Self-Attention模块,总共需要 3 h w C 2 + ( h w ) 2 C + ( h w ) 2 C = 3 h w C 2 + 2 ( h w ) 2 C 3hwC^2 + (hw)^2C + (hw)^2C=3hwC^2 + 2(hw)^2C 3hwC2+(hw)2C+(hw)2C=3hwC2+2(hw)2C。而在实际使用过程中,使用的是多头的Multi-head Self-Attention模块,在之前的文章中有进行过实验对比,多头注意力模块相比单头注意力模块的计算量仅多了最后一个融合矩阵 W O W_O WO的计算量 h w C 2 hwC^2 hwC2
B h w × C ⋅ W O C × C = O h w × C B^hw \\times C \\cdot W_O^C \\times C = O^hw \\times C Bhw×CWOC×C=Ohw×C
所以总共加起来是: 4 h w C 2 + 2 ( h w ) 2 C 4hwC^2 + 2(hw)^2C pytorch 笔记: Swin-Transformer 代码

Swin-Transformer 图像分割实战:使用Swin-Transformer-Semantic-Segmentation训练ADE20K数据集(语义分割)

改进YOLOv5系列:增加Swin-Transformer小目标检测头

详解docker桥接网络模型

LeNet详解

Linux系统编程---dup和dup2详解