注意力机制 SE-Net 原理与 TensorFlow2.0 实现
Posted K同学啊
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了注意力机制 SE-Net 原理与 TensorFlow2.0 实现相关的知识,希望对你有一定的参考价值。
文章目录
🍵 介绍
SENet 是 ImageNet 2017(ImageNet 收官赛)的冠军模型,是由WMW团队发布。具有复杂度低,参数少和计算量小的优点。且SENet 思路很简单,很容易扩展到已有网络结构如 Inception 和 ResNet 中。
🍛 SE 模块
已经有很多工作在空间维度上来提升网络的性能,如 Inception 等,而 SENet 将关注点放在了特征通道之间的关系上。其具体策略为:通过学习的方式来自动获取到每个特征通道的重要程度,然后依照这个重要程度去提升有用的特征并抑制对当前任务用处不大的特征,这又叫做“特征重标定”策略。具体的 SE 模块如下图所示:
给定一个输入 x x x ,其特征通道数为 c 1 c_1 c1,通过一系列卷积等一般变换 F t r F_tr Ftr 后得到一个特征通道数为 c 2 c_2 c2 的特征。与传统的卷积神经网络不同,我们需要通过下面三个操作来重标定前面得到的特征。
-
首先是 Squeeze 操作,我们顺着空间维度来进行特征压缩,将一个通道中整个空间特征编码为一个全局特征,这个实数某种程度上具有全局的感受野,并且输出的通道数和输入的特征通道数相等,例如将形状为 (1, 32, 32, 10) 的 feature map 压缩成 (1, 1, 1, 10)。此操作通常采用采用
global average pooling
来实现。 -
得到了全局描述特征后,我们进行 Excitation 操作来抓取特征通道之间的关系,它是一个类似于循环神经网络中门的机制:
这里采用包含两个全连接层的 bottleneck 结构,即中间小两头大的结构:其中第一个全连接层起到降维的作用,并通过 ReLU 激活,第二个全连接层用来将其恢复至原始的维度。进行 Excitation 操作的最终目的是为每个特征通道生成权重,即学习到的各个通道的激活值(sigmoid 激活,值在 0~1 之间)。 -
最后是一个 Scale 的操作,我们将 Excitation 的输出的权重看做是经过特征选择后的每个特征通道的重要性,然后通过乘法逐通道加权到先前的特征上,完成在通道维度上的对原始特征的重标定,从而使得模型对各个通道的特征更有辨别能力,这类似于attention机制。
🥡 SE 模块应用分析
SE模块的灵活性在于它可以直接应用现有的网络结构中。以 Inception 和 ResNet 为例,我们只需要在 Inception 模块或 Residual 模块后添加一个 SE 模块即可。具体如下图所示:
上图分别是将 SE 模块嵌入到 Inception 结构与 ResNet 中的示例,方框旁边的维度信息代表该层的输出,
r
r
r 表示 Excitation 操作中的降维系数。
🍘 SE 模型效果对比
SE 模块很容易嵌入到其它网络中,为了验证 SE 模块的作用,在其它流行网络如 ResNet 和 Inception 中引入 SE 模块,测试其在 ImageNet 上的效果,如下表所示:
首先看一下网络的深度对 SE 的影响。上表分别展示了 ResNet-50、ResNet-101、ResNet-152 和嵌入 SE 模型的结果。第一栏 Original 是原作者实现的结果,为了进行公平的比较,重新进行了实验得到 Our re-implementation 的结果。最后一栏 SE-module 是指嵌入了 SE 模块的结果,它的训练参数和第二栏 Our re-implementation 一致。括号中的红色数值是指相对于 Our re-implementation 的精度提升的幅值。
从上表可以看出,SE-ResNets 在各种深度上都远远超过了其对应的没有SE的结构版本的精度,这说明无论网络的深度如何,SE模块都能够给网络带来性能上的增益。值得一提的是,SE-ResNet-50 可以达到和ResNet-101 一样的精度;更甚,SE-ResNet-101 远远地超过了更深的ResNet-152。
上图展示了ResNet-50 和 ResNet-152 以及它们对应的嵌入SE模块的网络在ImageNet上的训练过程,可以明显地看出加入了SE模块的网络收敛到更低的错误率上。
🥙 SE 模块代码实现
import tensorflow as tf
class Squeeze_excitation_layer(tf.keras.Model):
def __init__(self, filter_sq):
# filter_sq 是 Excitation 中第一个卷积过程中卷积核的个数
super().__init__()
self.filter_sq = filter_sq
self.avepool = tf.keras.layers.GlobalAveragePooling2D()
self.dense = tf.keras.layers.Dense(filter_sq)
self.relu = tf.keras.layers.Activation('relu')
self.sigmoid = tf.keras.layers.Activation('sigmoid')
def call(self, inputs):
squeeze = self.avepool(inputs)
excitation = self.dense(squeeze)
excitation = self.relu(excitation)
excitation = tf.keras.layers.Dense(inputs.shape[-1])(excitation)
excitation = self.sigmoid(excitation)
excitation = tf.keras.layers.Reshape((1, 1, inputs.shape[-1]))(excitation)
scale = inputs * excitation
return scale
SE = Squeeze_excitation_layer(16)
inputs = np.zeros((1, 32, 32, 32), dtype=np.float32)
SE(inputs).shape
TensorShape([1, 32, 32, 32])
🍜 SE 模块插入到 DenseNet 代码实现
from tensorflow.keras.models import Model
from tensorflow.keras import layers
from tensorflow.keras import backend
def dense_block(x, blocks, name):
for i in range(blocks):
x = conv_block(x, 32, name=name + '_block' + str(i + 1))
return x
def conv_block(x, growth_rate, name):
bn_axis = 3
x1 = layers.BatchNormalization(axis=bn_axis,
epsilon=1.001e-5,
name=name + '_0_bn')(x)
x1 = layers.Activation('relu', name=name + '_0_relu')(x1)
x1 = layers.Conv2D(4 * growth_rate, 1,
use_bias=False,
name=name + '_1_conv')(x1)
x1 = layers.BatchNormalization(axis=bn_axis, epsilon=1.001e-5,
name=name + '_1_bn')(x1)
x1 = layers.Activation('relu', name=name + '_1_relu')(x1)
x1 = layers.Conv2D(growth_rate, 3,
padding='same',
use_bias=False,
name=name + '_2_conv')(x1)
x = layers.Concatenate(axis=bn_axis, name=name + '_concat')([x, x1])
return x
def transition_block(x, reduction, name):
bn_axis = 3
x = layers.BatchNormalization(axis=bn_axis, epsilon=1.001e-5,
name=name + '_bn')(x)
x = layers.Activation('relu', name=name + '_relu')(x)
x = layers.Conv2D(int(backend.int_shape(x)[bn_axis] * reduction), 1,
use_bias=False,
name=name + '_conv')(x)
x = layers.AveragePooling2D(2, strides=2, name=name + '_pool')(x)
return x
def DenseNet(blocks, input_shape=None, classes=1000, **kwargs):
img_input = layers.Input(shape=input_shape)
bn_axis = 3
# 224,224,3 -> 112,112,64
x = layers.ZeroPadding2D(padding=((3, 3), (3, 3)))(img_input)
x = layers.Conv2D(64, 7, strides=2, use_bias=False, name='conv1/conv')(x)
x = layers.BatchNormalization(
axis=bn_axis, epsilon=1.001e-5, name='conv1/bn')(x)
x = layers.Activation('relu', name='conv1/relu')(x)
# 112,112,64 -> 56,56,64
x = layers.ZeroPadding2D(padding=((1, 1), (1, 1)))(x)
x = layers.MaxPooling2D(3, strides=2, name='pool1')(x)
# 56,56,64 -> 56,56,64+32*block[0]
# Densenet121 56,56,64 -> 56,56,64+32*6 == 56,56,256
x = dense_block(x, blocks[0], name='conv2')
# 56,56,64+32*block[0] -> 28,28,32+16*block[0]
# Densenet121 56,56,256 -> 28,28,32+16*6 == 28,28,128
x = transition_block(x, 0.5, name='pool2')
# 28,28,32+16*block[0] -> 28,28,32+16*block[0]+32*block[1]
# Densenet121 28,28,128 -> 28,28,128+32*12 == 28,28,512
x = dense_block(x, blocks[1], name='conv3')
# Densenet121 28,28,512 -> 14,14,256
x = transition_block(x, 0.5, name='pool3')
# Densenet121 14,14,256 -> 14,14,256+32*block[2] == 14,14,1024
x = dense_block(x, blocks[2], name='conv4')
# Densenet121 14,14,1024 -> 7,7,512
x = transition_block(x, 0.5, name='pool4')
# Densenet121 7,7,512 -> 7,7,256+32*block[3] == 7,7,1024
x = dense_block(x, blocks[3], name='conv5')
# 加SE注意力机制
x = Squeeze_excitation_layer(16)(x)
x = layers.BatchNormalization(axis=bn_axis, epsilon=1.001e-5, name='bn')(x)
x = layers.Activation('relu', name='relu')(x)
x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
x = layers.Dense(classes, activation='softmax', name='fc1000')(x)
inputs = img_input
if blocks == [6, 12, 24, 16]:
model = Model(inputs, x, name='densenet121')
elif blocks == [6, 12, 32, 32]:
model = Model(inputs, x, name='densenet169')
elif blocks == [6, 12, 48, 32]:
model = Model(inputs, x, name='densenet201')
else:
model = Model(inputs, x, name='densenet')
return model
def DenseNet121(input_shape=[224,224,3], classes=3, **kwargs):
return DenseNet([6, 12, 24, 16], input_shape, classes, **kwargs)
def DenseNet169(input_shape=[224,224,3], classes=3, **kwargs):
return DenseNet([6, 12, 32, 32], input_shape, classes, **kwargs)
def DenseNet201(input_shape=[224,224,3], classes=3, **kwargs):
return DenseNet([6, 12, 48, 32], input_shape, classes, **kwargs)
参考文章:
以上是关于注意力机制 SE-Net 原理与 TensorFlow2.0 实现的主要内容,如果未能解决你的问题,请参考以下文章