MS-Model:nnFormer

Posted zzzyzh

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了MS-Model:nnFormer相关的知识,希望对你有一定的参考价值。

文章目录


前言

本文在医学图像分割领域中的另一个十分常用的基线网络 nnUNet 的基础上修改得到,在多器官分割任务(十项全能数据集)上取得了十分不错的成绩

原论文链接:nnFormer: Interleaved Transformer for Volumetric Segmentation

论文复现参考:MS-Train【2】:nnFormer

本文中设计到的 3 个重要模型可以参考我的其他 blog:
CV-Model【6】:Vision Transformer
CV-Model【7】:Swin Transformer
MS-Model【1】:nnU-Net


1. Abstract & Introduction

1.1. Abstract

目前的方法要么不采用 Transformer,要么使用 Transformer 的效率不够高,无法捕捉医学成像中的长期依赖性

nnFormer 不仅利用交错卷积和自我注意操作的结合,而且还引入了基于局部和全局体积的自我注意机制来学习体积表示。此外,nnFormer 提出使用跳过注意力来取代传统的类似 U-Net 架构中跳过连接的串联/求和操作

这项任务是对三维计算机断层扫描(CT)中捕获的不同器官进行分割

1.2. Introduction

过往的一些主流模型通常将 ConvNets 作为主体,在此基础上进一步应用转化器来捕捉长期的依赖关系,但这样无法充分的发挥 Transformer 的优势。换句话说,一到两层的变换器不足以将长期依赖关系与卷积表征纠缠在一起,而卷积表征通常包含精确的空间信息并提供分层的概念

本文在技术上的主要贡献:

  • 卷积和自我注意操作的交错组合
  • 利用基于局部和全局体积的自我注意,分别建立特征金字塔和提供大的感受野
  • 提出跳过注意,以取代跳过连接中的传统连接 / 求和操作

1.3. Related work

由于 Transformer 本身可以有效地捕捉和利用像素或体素之间的长期依赖(long-term dependencies),近期出现了非常多结合 CNNTransformer 的针对医疗影像处理的模型和网络。其中大部分结果表明,在 CNN 中合适的位置嵌入类 Transformer 的结构,可以有效地提升网络的性能

基于 Transformer 的医疗影像处理模型和网络通常可以分为两类:

  • 仍然使用 CNN 作为主要的特征提取器,辅以类 Transformer 结构以捕捉特征中的全局信息,再将此信息嵌入到 CNN
    • nnU-Net
      • 目前性能最好的全卷积医学分割神经网络
      • nnU-Net 是 U-Net 架构的集合体,具有数据预处理、数据增强和后处理的自动化管道
      • 对二维窗口 patches 比三维体积 patches 效果更好
    • TransUNet
      • 第一个提出的架构,在医学图像分割的背景下利用 Transformer
      • Convnets 被设计为特征提取器,Transformer 层被覆盖以帮助编码全局背景
    • Swin-UNet
      • 在一个类似 U-Net 的架构中使用一个编码器-解码器
      • Swin-UNet 使用 ConvNets 中使用的特征金字塔,然后在其上设置 Transformer
  • 直接使用纯 Transformer 结构进行处理
    • Convolution-free medical image segmentation using transformers
      • 首次引入了无卷积的分割模型,将扁平化的图像表示转发给 transformers
      • 输出被重组为三维张量,与分割掩码对齐

相关工作的缺点:

  • Transformer 的优势没有得到充分的利用,几层 Transformer 不足以纠缠长期的依赖关系
  • 由于卷积表征包含精确的空间信息,这种信息在一组多幅图像(三维斑块的二维窗口)上会丢失
  • 大多数方法将卷积网作为基础特征提取器,Transformer 只在顶部应用,以帮助从卷积网中提取的特征向量编码全局背景
  • 只使用变换器,通过直接压平原始像素和应用一维预处理并不能提供足够丰富的特征集来建立模型
  • ConvNets 是图像数据的首选工具,因为它们能捕获精确的局部特征,因此需要将它们纳入模型

nnFormer 的优势:

  • 混合 stem
  • 卷积和自关注交错使用,以充分发挥它们的优势
    • Convolution:捕捉精确的局部信息。
    • Self-Attention:捕捉长期的依赖关系

2. Method

2.1. Overview

nnFormer 的整体架构如上图所示,它保持了与 U-Net 类似的 U 型结构,主要由三部分组成,即 EncoderBottleneckDecoder

  • Encoder 包括一个嵌入层、两个局部 transformer 块(每个块包含两个连续的层)和两个下采样层
  • 对称的是,Decoder 分支包括两个 transformer 块,两个上采样层和最后一个用于进行掩码预测的补丁扩展层
  • Bottleneck 部分包括一个下采样层、一个上采样层和三个全局 transformer 块,用于提供大的接收场以支持 Decoder

受 U-Net 的启发,本文在 EncoderDecoder 的相应特征金字塔之间以对称的方式添加了跳过连接,这有助于恢复预测中的细粒度细节。然而,与通常使用求和或串联操作的非典型跳过连接不同,本文引入了跳过关注来弥补 EncoderDecoder 之间的差距

Fig 2 图 a 中的 nnFormer 的详细结构如下图所示:

2.2. Encoder

nnFormer 的输入是一个三维补丁 X ∈ R H × W × D X \\in R^H \\times W \\times D XRH×W×D(通常是从原始图像中随机裁剪的),参数含义:

  • H , W , D H, W, D H,W,D 分别表示每个输入扫描的高度、宽度和深度

2.2.1. Components

nnFormer 使用混合 stem,其中卷积和自我注意被交错使用,以充分发挥它们各自的优势

  • 把一个轻量级的 Convolutional embedding layer 放在 Transformer block 的前面
    • 这个嵌入层对精确的像素级空间信息进行编码,并提供低水平但高分辨率的三维特征
  • 在嵌入块之后,Transformer block 和卷积下采样块交错在一起使用
    • 以充分融合不同尺度的高层次和分层物体概念的长期依赖关系,这有助于提高学习表征的泛化能力和稳健性

2.2.2. The embedding layer

Embedding block 将每个输入扫描 X X X 转化为高维张量 X e ∈ R H 4 × W 4 × D 2 × C X_e \\in R^\\fracH4 \\times \\fracW4 \\times \\fracD2 \\times C XeR4H×4W×2D×C
参数含义:

  • H 4 × W 4 × D 2 \\fracH4 \\times \\fracW4 \\times \\fracD2 4H×4W×2D 代表补丁标记的数量
  • C C C 代表序列长度(这些数字在不同的数据集上可能略有不同)

ViTSwin Transformer 在嵌入块中使用大的卷积核来提取特征不同,本文发现应用小的卷积核的连续卷积层在初始阶段带来更多的好处:

  • 应用连续的卷积层
    • 在嵌入块中使用卷积层,因为它们对像素级的空间信息进行编码,比变换器中使用的补丁式位置编码更精确
  • 小尺寸核
    • 与大尺寸的内核相比,小的内核尺寸有助于降低计算的复杂性,同时提供同等大小的感受野

上图所示的 Embedding block 是一个四层的卷积结构(针对不同数据集参数上可能会有出入,具体参考 Fig 2 图 b)

  • 核大小为 3
  • 在每个卷积层之后(除了最后一个),附加一个 GELU 激活函数和一个 layer normalization

Embedding block 主要用来将输入的影像转化为网络可以处理的特征。使用四层的卷积来处理输入的原因如下:

  • 卷积网络可以更好的保留更加精确的位置信息
  • 卷积操作可以提供高分辨率的底层特征,这是后面应用 Transformer block 的基础

2.2.3. Local Volume-based Multi-head Self-attention (LV-MSA)

nnFormer 在三维局部体积内计算 self-attention

假设 X L V ∈ R L × C X_LV \\in R^L \\times C XLVRL×C 代表 local transformer block 的输入

  • 首先被重塑为 X ^ L V ∈ R N L V × N T × C \\hatX_LV \\in R^N_LV \\times N_T \\times C X^LVRNLV×NT×C
    • N L V N_LV NLV 是预先定义的三维局部
    • N T = S H × S W × S D N_T = S_H \\times S_W \\times S_D NT=SH×SW×SD 表示每个 volume 中补丁标记的数量
    • S H , S W , S D \\ S_H, S_W, S_D \\ SH,SW,SD 代表局部 volume 的大小

如下图所示:在每个区块中进行两个连续的 transformer 层,其中第二层可以被视为第一层的移位版本(即 SLV-MSA

计算过程可以总结为以下几点:

l l l 代表层的索引, M L P MLP MLP 代表多层感知机

LV-MSA 在一个 h × w × d h \\times w \\times d h×w×d 的 patches 体积上的计算复杂度为:

SLV-MSALV-MSA 中使用的三维局部体积置换为 ( ⌊ S H 2 ⌋ , ⌊ S W 2 ⌋ , ⌊ S D 2 ⌋ ) (\\lfloor \\fracS_H2 \\rfloor, \\lfloor \\fracS_W2 \\rfloor, \\lfloor \\fracS_D2 \\rfloor) (⌊2SH,2SW,2SD⌋),以引入不同局部体积之间的更多相互作用

在实践中,SLV-MSA 的计算复杂度与 LV-MSA 相似

相较于传统的 voxel 和 voxel 之间计算 self-attention 的方式,LV-MSA 可以大大地降低计算的复杂度,这些降低的复杂度主要集中在网络早期的计算过程中,伴随着特征空间维度的下降 ( H , W , D ) (H, W, D) (H,W,D) 以及通道输入 ( C ) (C) (C) 的增多,其实这种优势就不明显了

每个三维局部体中 query-key-value (QKV) attention 可以通过以下公式计算:

参数含义:

  • Q , K , V ∈ R N T × d k Q, K, V \\in R^N_T \\times d_k Q,K,VRNT×dk 表示 query,key 和 value 的矩阵
  • B ∈ R N T B \\in R^N_T BRNT 是相对位置编码

2.2.4. The down-sampling layer

卷积下采样产生了层次化的表示,有助于在多个尺度上对物体概念进行建模

进行下采样的原因:

  • 多次下采样可以建立多尺度的特征金字塔结构
  • 下采样可以大大降低 GPU 显存的消耗

在大多数情况下,下采样层涉及到一个跨度卷积操作,其中跨度在所有维度上都被设置为 2。然而,在实践中,关于特定维度的步长可以设置为 1,因为在这个维度上,切片的数量是有限的,过度下采样(即使用大的下采样步长)可能是有害的

2.3. Bottleneck

将二维 multi-head self-attention 机制扩展到三维版本,如下图所示:

其计算复杂性可以表述为:

h , w , d \\ h, w, d \\ h,w,d 相比 S H , S W , S D \\ S_H, S_W, S_D \\ SH,SW,SD 较大时,GV-MSA 需要更多的计算资源

Bottleneck 中, h , w , d \\ h, w, d \\ h,w,d 在经过几个下采样层后已经变得小得多,使得它们的乘积,即 h w d hwd hwd, ,具有与 S H S W S D S_H S_W S_D SHSW以上是关于MS-Model:nnFormer的主要内容,如果未能解决你的问题,请参考以下文章

实战霸榜各大医学分割挑战赛的Transformer架构--nnFormer

论文解读丨无监督视觉表征学习的动量对比

论文解读丨无监督视觉表征学习的动量对比

长文回顾 | Microsoft Cloud&AI 李琳婕: 多模态预训练模型UNITER, 通用的图像-文本语言表征学习

带你读论文向量表征经典之DeepWalk

带你读论文向量表征经典之DeepWalk