Xception
Posted flymanjb
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Xception相关的知识,希望对你有一定的参考价值。
目录
论文: Xception: Deep Learning with Depthwise Separable Convolutions
论文地址: https://arxiv.org/abs/1610.02357
代码地址:
参考博客:
1. 提出背景
Xception 是 google 继 Inception 后提出的对 Inception v3 的另一种改进,主要采用 depthwise separable convolution来替代原来的 Inception v3 中的卷积操作.
1. Inception模块简介
Inception v3 的结构图如下Figure1:
当时提出Inception 的初衷可以认为是:
特征提取和传递可以通过 1x1,3x3,5x5 conv以及pooling,究竟哪种提取特征方式好呢,Inception 结构将这个疑问留给网络自己训练,也就是将一个输入同时输出给这几种特征提取方式,然后做Concatnate.
- Inception v3 和 Inception v1 主要的区别是将 5x5的卷积核换成了2个 3x3 卷积核的叠加.
2. Inception模型简化
正如之前论文RexNeXt所说,Inception网络太依赖于人工设计了。于是结合ResNeXt的思想,从Inception V3联想到简化的Inception结构,就是Figure 2.
3. Inception模型拓展
我们可以做个等效的变换,事实上效果是一样的,有了Figure 3:
Figure 3 表示对于一个输入,先用一个统一的 1x1 的卷积核卷积,然后再连接3个 3x3卷积,这三个操作只将前面 1x1 卷积结果的一部分作为自己的输入(这里是1/3channel)的卷积.
既然如此,不如干脆点:
3x3的卷积核的个数和 1x1的输出channel 一样多,每个 3x3卷积都只和1个输入的channel做卷积.
2.论文核心
2.1 Depthwise Separable Convolution 深度分离卷积
DepthWise卷积PointWise卷积,合起来称作DepthWise Separable Convolution,该结构和常规卷积操作类似,可用来提取特征。但是相比较常规卷积操作,其参数量和运算成本较低,所以在一些轻量级网络中会碰到这种结构,比如说MobileNet.
2.1.1 常规卷积操作
对于一张 5x5 像素,三通道彩色输入图片(5x5x3),经过3x3卷积核的层,假设输出通道数量为4,则卷积核的shape为 3x3x3x4,最终输出4个Feature Map.
- 如果为
padding=same
,则特征图尺寸为5x5 padding=valid
,特征图尺寸3x3
2.1.2 DepthWise Convolution
不同于常规卷积操作:
DepthWise Convolution的一个卷积核只负责一个通道,一个通道只能被一个卷积核卷积.
同样对于这张5x5像素,三通道的彩色输入图片,DepthWise Convolution首先经过第一次卷积运算,不同于上面的常规卷积
DepthWise Convolution完全在二维平面进行,卷积核数量与上一层必须一致.(上一层通道与卷积核个数一致)
所以一个三通道的图像经过运算后生成了3个Feature Map.
但是这就存在一个缺点,首先:
- DepthWise Convolution 完成后 Feature Map 数量和输入层的通道数量相同,无法拓展Feature Map.
- 这种运算对输入层的每个通道独立进行卷积运算,没有有效利用在相同空间位置上的 feature 信息.
因此采用 PointWise Convolution 来将这些Feature Map重新组合生成新的 Feature Map.
2.1.3 PointWise Convolution
PointWise Convolution 的运算与常规卷积运算非常相似,它的卷积核的尺寸为 1x1xM,其中M为上一层的通道数量:
这里PointWise Convolution 运算会将上一步的map在深度方向上进行加权组合,生成新的Feature map,有多少个卷积核就有多少个输出Feature Maps.
有意思的是其实之前许多网络,例如Inception v3用PointWise Convolution来做维度缩减来降低参数,这里用来联系以及拓展Feature Maps。而DepthWise Convolution也并不是新出现的,它可以看做是分组卷积的特例,早在AlexNet就出现过.
3. 网络结构
Xception作为Inception v3的改进,主要是在Inception v3的基础上引入了depthwise separable convolution,在基本不增加网络复杂度的前提下提高了模型的效果.
疑问
- 有些人会好奇为什么引入depthwise separable convolution没有大大降低网络的复杂度?
原因在于作者加宽了网络,使得参数数量和Inception v3差不多,然后在这前提下比较性能.因此Xception目的不在于模型压缩,而是提高性能.
4. 核心代码
from keras.models import Model
from keras.layers import Dense, Input, BatchNormalization, Activation, add
from keras.layers import Conv2D, SeparableConv2D, MaxPooling2D, GlobalAveragePooling2D
from keras.applications.imagenet_utils import _obtain_input_shape
from keras.utils import plot_model
def Xception():
input_shape = _obtain_input_shape(None, default_size=299, min_size=71, data_format=‘channels_last‘, require_flatten=True)
img_input = Input(shape=input_shape)
# Block 1
x = Conv2D(32, (3, 3), strides=(2, 2), use_bias=False)(img_input)
x = BatchNormalization()(x)
x = Activation(‘relu‘)(x)
x = Conv2D(64, (3, 3), use_bias=False)(x)
x = BatchNormalization()(x)
x = Activation(‘relu‘)(x)
residual = Conv2D(128, (1, 1), strides=(2, 2), padding=‘same‘, use_bias=False)(x)
residual = BatchNormalization()(residual)
# Block 2
x = SeparableConv2D(128, (3, 3), padding=‘same‘, use_bias=False)(x)
x = BatchNormalization()(x)
x = Activation(‘relu‘)(x)
x = SeparableConv2D(128, (3, 3), padding=‘same‘, use_bias=False)(x)
x = BatchNormalization()(x)
x = MaxPooling2D((3, 3), strides=(2, 2), padding=‘same‘)(x)
x = add([x, residual])
residual = Conv2D(256, (1, 1), strides=(2, 2), padding=‘same‘, use_bias=False)(x)
residual = BatchNormalization()(residual)
# Block 3
x = Activation(‘relu‘)(x)
x = SeparableConv2D(256, (3, 3), padding=‘same‘, use_bias=False)(x)
x = BatchNormalization()(x)
x = Activation(‘relu‘)(x)
x = SeparableConv2D(256, (3, 3), padding=‘same‘, use_bias=False)(x)
x = BatchNormalization()(x)
x = MaxPooling2D((3, 3), strides=(2, 2), padding=‘same‘)(x)
x = add([x, residual])
residual = Conv2D(728, (1, 1), strides=(2, 2), padding=‘same‘, use_bias=False)(x)
residual = BatchNormalization()(residual)
# Block 4
x = Activation(‘relu‘)(x)
x = SeparableConv2D(728, (3, 3), padding=‘same‘, use_bias=False)(x)
x = BatchNormalization()(x)
x = Activation(‘relu‘)(x)
x = SeparableConv2D(728, (3, 3), padding=‘same‘, use_bias=False)(x)
x = BatchNormalization()(x)
x = MaxPooling2D((3, 3), strides=(2, 2), padding=‘same‘)(x)
x = add([x, residual])
# Block 5-12
for i in range(8):
residual = x
x = Activation(‘relu‘)(x)
x = SeparableConv2D(728, (3, 3), padding=‘same‘, use_bias=False)(x)
x = BatchNormalization()(x)
x = Activation(‘relu‘)(x)
x = SeparableConv2D(728, (3, 3), padding=‘same‘, use_bias=False)(x)
x = BatchNormalization()(x)
x = Activation(‘relu‘)(x)
x = SeparableConv2D(728, (3, 3), padding=‘same‘, use_bias=False)(x)
x = BatchNormalization()(x)
x = add([x, residual])
residual = Conv2D(1024, (1, 1), strides=(2, 2), padding=‘same‘, use_bias=False)(x)
residual = BatchNormalization()(residual)
# Block 13
x = Activation(‘relu‘)(x)
x = SeparableConv2D(728, (3, 3), padding=‘same‘, use_bias=False)(x)
x = SeparableConv2D(1024, (3, 3), padding=‘same‘, use_bias=False)(x)
x = BatchNormalization()(x)
# Block 13 Pool
x = MaxPooling2D((3, 3), strides=(2, 2), padding=‘same‘)(x)
x = add([x, residual])
# Block 14
x = SeparableConv2D(1536, (3, 3), padding=‘same‘, use_bias=False)(x)
x = BatchNormalization()(x)
x = Activation(‘relu‘)(x)
# Block 14 part2
x = SeparableConv2D(2048, (3, 3), padding=‘same‘, use_bias=False)(x)
x = BatchNormalization()(x)
x = Activation(‘relu‘)(x)
# 全链接层
x = GlobalAveragePooling2D()(x)
x = Dense(1000, activation=‘softmax‘)(x)
return Model(inputs=img_input, outputs=x, name=‘xception‘)
if __name__ == ‘__main__‘:
model = Xception()
model.summary()
plot_model(model, show_shapes=True)
以上是关于Xception的主要内容,如果未能解决你的问题,请参考以下文章