tensorflow中使用Batch Normalization

Posted pclover11

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了tensorflow中使用Batch Normalization相关的知识,希望对你有一定的参考价值。

在深度学习中为了提高训练速度,经常会使用一些正正则化方法,如L2、dropout,后来Sergey Ioffe 等人提出Batch Normalization方法,可以防止数据分布的变化,影响神经网络需要重新学习分布带来的影响,会降低学习速率,训练时间等问题。提出使用batch normalization方法,使输入数据分布规律保持一致。实验证明可以提升训练速度,提高识别精度。下面讲解一下在Tensorflow中如何使用Batch Normalization

有关Batch Normalization详细内容请查看论文:

Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift

 

关键函数

tf.layers.batch_normalization、tf.contrib.layers.batch_norm

这两个函数用法一致,以 tf.layers.batch_normalization 为例进行讲解

layer1_conv = tf.layers.batch_normalization(layer1_conv,axis=0,training=in_training)

其中 axis 参数表示沿着哪个轴进行正则化,一般而言Tensor是[batch, width_x, width_y, channel],如果是[width_x, width_y, channel,batch]则axis应该设为3

 

1 在训练阶段

训练的时候需要注意两点,(1)输入参数training=True,(2)计算loss时,要添加以下代码(即添加update_ops到最后的train_op中)。这样才能计算μ和σ的滑动平均(测试时会用到)

 update_op = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  with tf.control_dependencies(update_op):
    train_op = optimizer.minimize(loss)

 

2 在测试阶段

测试时需要注意一点,输入参数training=False,

 

以上是关于tensorflow中使用Batch Normalization的主要内容,如果未能解决你的问题,请参考以下文章

Tensorflow Batch normalization函数

在tensorflow中使用batch normalization

TensorFlow中random_normal和truncated_normal的区别

Linux 内核调度器 ⑨ ( Linux 内核调度策略 | SCHED_NORMAL 策略 | SCHED_FIFO 策略 | SCHED_NORMAL 策略 | SCHED_BATCH策略 )

tensorflow:batch and shuffle_batch

Tensorflow:如何使用来自 cifar10 的 tf.train.batch 绘制小批量?