Xception

Posted flymanjb

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Xception相关的知识,希望对你有一定的参考价值。

目录

论文: Xception: Deep Learning with Depthwise Separable Convolutions

论文地址: https://arxiv.org/abs/1610.02357

代码地址:

  1. Keras: https://github.com/yanchummar/xception-keras

参考博客:

  1. Xception算法详解
  2. Depthwise卷积与Pointwise卷积

1. 提出背景

Xception 是 google 继 Inception 后提出的对 Inception v3 的另一种改进,主要采用 depthwise separable convolution来替代原来的 Inception v3 中的卷积操作.

1. Inception模块简介

Inception v3 的结构图如下Figure1:

技术分享图片

当时提出Inception 的初衷可以认为是:

特征提取和传递可以通过 1x1,3x3,5x5 conv以及pooling,究竟哪种提取特征方式好呢,Inception 结构将这个疑问留给网络自己训练,也就是将一个输入同时输出给这几种特征提取方式,然后做Concatnate.

  • Inception v3 和 Inception v1 主要的区别是将 5x5的卷积核换成了2个 3x3 卷积核的叠加.

2. Inception模型简化

正如之前论文RexNeXt所说,Inception网络太依赖于人工设计了。于是结合ResNeXt的思想,从Inception V3联想到简化的Inception结构,就是Figure 2.

技术分享图片

3. Inception模型拓展

我们可以做个等效的变换,事实上效果是一样的,有了Figure 3:

技术分享图片

Figure 3 表示对于一个输入,先用一个统一的 1x1 的卷积核卷积,然后再连接3个 3x3卷积,这三个操作只将前面 1x1 卷积结果的一部分作为自己的输入(这里是1/3channel)的卷积.

既然如此,不如干脆点:

3x3的卷积核的个数和 1x1的输出channel 一样多,每个 3x3卷积都只和1个输入的channel做卷积.

技术分享图片

2.论文核心

2.1 Depthwise Separable Convolution 深度分离卷积

DepthWise卷积PointWise卷积,合起来称作DepthWise Separable Convolution,该结构和常规卷积操作类似,可用来提取特征。但是相比较常规卷积操作,其参数量和运算成本较低,所以在一些轻量级网络中会碰到这种结构,比如说MobileNet.

2.1.1 常规卷积操作

技术分享图片

对于一张 5x5 像素,三通道彩色输入图片(5x5x3),经过3x3卷积核的层,假设输出通道数量为4,则卷积核的shape为 3x3x3x4,最终输出4个Feature Map.

  • 如果为padding=same,则特征图尺寸为5x5
  • padding=valid,特征图尺寸3x3

2.1.2 DepthWise Convolution

不同于常规卷积操作:

DepthWise Convolution的一个卷积核只负责一个通道,一个通道只能被一个卷积核卷积.

技术分享图片

同样对于这张5x5像素,三通道的彩色输入图片,DepthWise Convolution首先经过第一次卷积运算,不同于上面的常规卷积

DepthWise Convolution完全在二维平面进行,卷积核数量与上一层必须一致.(上一层通道与卷积核个数一致)

所以一个三通道的图像经过运算后生成了3个Feature Map.

但是这就存在一个缺点,首先:

  1. DepthWise Convolution 完成后 Feature Map 数量和输入层的通道数量相同,无法拓展Feature Map.
  2. 这种运算对输入层的每个通道独立进行卷积运算,没有有效利用在相同空间位置上的 feature 信息.

因此采用 PointWise Convolution 来将这些Feature Map重新组合生成新的 Feature Map.

2.1.3 PointWise Convolution

PointWise Convolution 的运算与常规卷积运算非常相似,它的卷积核的尺寸为 1x1xM,其中M为上一层的通道数量:

技术分享图片

这里PointWise Convolution 运算会将上一步的map在深度方向上进行加权组合,生成新的Feature map,有多少个卷积核就有多少个输出Feature Maps.

有意思的是其实之前许多网络,例如Inception v3用PointWise Convolution来做维度缩减来降低参数,这里用来联系以及拓展Feature Maps。而DepthWise Convolution也并不是新出现的,它可以看做是分组卷积的特例,早在AlexNet就出现过.

3. 网络结构

技术分享图片

Xception作为Inception v3的改进,主要是在Inception v3的基础上引入了depthwise separable convolution,在基本不增加网络复杂度的前提下提高了模型的效果.

疑问

  1. 有些人会好奇为什么引入depthwise separable convolution没有大大降低网络的复杂度?

原因在于作者加宽了网络,使得参数数量和Inception v3差不多,然后在这前提下比较性能.因此Xception目的不在于模型压缩,而是提高性能.

4. 核心代码

from keras.models import Model
from keras.layers import Dense, Input, BatchNormalization, Activation, add
from keras.layers import Conv2D, SeparableConv2D, MaxPooling2D, GlobalAveragePooling2D
from keras.applications.imagenet_utils import _obtain_input_shape
from keras.utils import plot_model


def Xception():
    input_shape = _obtain_input_shape(None, default_size=299, min_size=71, data_format=‘channels_last‘, require_flatten=True)
    img_input = Input(shape=input_shape)

    # Block 1
    x = Conv2D(32, (3, 3), strides=(2, 2), use_bias=False)(img_input)
    x = BatchNormalization()(x)
    x = Activation(‘relu‘)(x)
    x = Conv2D(64, (3, 3), use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation(‘relu‘)(x)

    residual = Conv2D(128, (1, 1), strides=(2, 2), padding=‘same‘, use_bias=False)(x)
    residual = BatchNormalization()(residual)

    # Block 2
    x = SeparableConv2D(128, (3, 3), padding=‘same‘, use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation(‘relu‘)(x)
    x = SeparableConv2D(128, (3, 3), padding=‘same‘, use_bias=False)(x)
    x = BatchNormalization()(x)
    x = MaxPooling2D((3, 3), strides=(2, 2), padding=‘same‘)(x)

    x = add([x, residual])

    residual = Conv2D(256, (1, 1), strides=(2, 2), padding=‘same‘, use_bias=False)(x)
    residual = BatchNormalization()(residual)

    # Block 3
    x = Activation(‘relu‘)(x)
    x = SeparableConv2D(256, (3, 3), padding=‘same‘, use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation(‘relu‘)(x)
    x = SeparableConv2D(256, (3, 3), padding=‘same‘, use_bias=False)(x)
    x = BatchNormalization()(x)
    x = MaxPooling2D((3, 3), strides=(2, 2), padding=‘same‘)(x)

    x = add([x, residual])

    residual = Conv2D(728, (1, 1), strides=(2, 2), padding=‘same‘, use_bias=False)(x)
    residual = BatchNormalization()(residual)

    # Block 4
    x = Activation(‘relu‘)(x)
    x = SeparableConv2D(728, (3, 3), padding=‘same‘, use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation(‘relu‘)(x)
    x = SeparableConv2D(728, (3, 3), padding=‘same‘, use_bias=False)(x)
    x = BatchNormalization()(x)
    x = MaxPooling2D((3, 3), strides=(2, 2), padding=‘same‘)(x)

    x = add([x, residual])

    # Block 5-12
    for i in range(8):
        residual = x
        x = Activation(‘relu‘)(x)
        x = SeparableConv2D(728, (3, 3), padding=‘same‘, use_bias=False)(x)
        x = BatchNormalization()(x)
        x = Activation(‘relu‘)(x)
        x = SeparableConv2D(728, (3, 3), padding=‘same‘, use_bias=False)(x)
        x = BatchNormalization()(x)
        x = Activation(‘relu‘)(x)
        x = SeparableConv2D(728, (3, 3), padding=‘same‘, use_bias=False)(x)
        x = BatchNormalization()(x)

        x = add([x, residual])

    residual = Conv2D(1024, (1, 1), strides=(2, 2), padding=‘same‘, use_bias=False)(x)
    residual = BatchNormalization()(residual)

    # Block 13
    x = Activation(‘relu‘)(x)
    x = SeparableConv2D(728, (3, 3), padding=‘same‘, use_bias=False)(x)
    x = SeparableConv2D(1024, (3, 3), padding=‘same‘, use_bias=False)(x)
    x = BatchNormalization()(x)

    # Block 13 Pool
    x = MaxPooling2D((3, 3), strides=(2, 2), padding=‘same‘)(x)
    x = add([x, residual])

    # Block 14
    x = SeparableConv2D(1536, (3, 3), padding=‘same‘, use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation(‘relu‘)(x)

    # Block 14 part2
    x = SeparableConv2D(2048, (3, 3), padding=‘same‘, use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation(‘relu‘)(x)

    # 全链接层
    x = GlobalAveragePooling2D()(x)
    x = Dense(1000, activation=‘softmax‘)(x)

    return Model(inputs=img_input, outputs=x, name=‘xception‘)


if __name__ == ‘__main__‘:
    model = Xception()
    model.summary()
    plot_model(model, show_shapes=True)

技术分享图片

以上是关于Xception的主要内容,如果未能解决你的问题,请参考以下文章

深度可分离卷积网络Xception 网络解析

Xception

Xception网络结构理解

Xception(图像分类)中的损失和准确性没有提高

JSP慕课网之applicationpagepageContextconfigexception

Xception实现动物识别(TensorFlow)