tensorflow学习率控制及调试
Posted yealxxy
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了tensorflow学习率控制及调试相关的知识,希望对你有一定的参考价值。
在深度学习中,学习率变化对模型收敛的结果影响很大,因此很多时候都需要控制学习率的变化。本文以tensorflow实现learning rate test为例,讲述学习率变化控制的方法,以及怎么调试。
一、learning rate test
学习率测试(learning rate test)是一个找到学习率变化的范围的测试,详情可以查看自 Adam 出现以来,深度学习优化器发生了什么变化
二、tensflow实现学习率测试
- 控制学习率以线性,或者指数形式增长
def lr_test(global_step,min_lr=1e-5, max_lr=1e1, steps_per_epoch=784, epochs=20, linear=False):
if global_step is None:
raise ValueError("global_step is required for cyclic_learning_rate.")
learning_rate = ops.convert_to_tensor(min_lr, name="learning_rate")
dtype = learning_rate.dtype
global_step = math_ops.cast(global_step, dtype)
total_iterations = tf.cast(steps_per_epoch * epochs, dtype)
min_lr = tf.cast(min_lr, dtype)
max_lr = tf.cast(max_lr, dtype)
if linear:
lr_mult = tf.cast(max_lr / min_lr / total_iterations, dtype)
else:
lr_mult = tf.cast((max_lr / min_lr) ** (1 / total_iterations), dtype)
mult = lr_mult * global_step if linear else lr_mult ** global_step
return min_lr * mult
- 创建优化器时,改变学习率
if config.optimizer == 'Adam':
optimizer = tf.train.AdamOptimizer(
learning_rate = lr_test(global_step=self.global_step),
beta1 = config.beta1,
beta2 = config.beta2,
epsilon = config.epsilon
)
- 创建优化过程
opt_op = optimizer.minimize(self.total_loss,global_step=self.global_step)
三、查看学习率变化过程
- 计算每一步更新时的学习率
optimizer = tf.train.AdamOptimizer中有一个属性记录跌倒的学习率。
lr = 0.1
step_rate = 1000
decay = 0.95
global_step = tf.Variable(0, trainable=False)
increment_global_step = tf.assign(global_step, global_step + 1)
learning_rate = tf.train.exponential_decay(lr, global_step, step_rate, decay, staircase=True)
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate, epsilon=0.01)
trainer = optimizer.minimize(loss_function)
# Some code here
print('Learning rate: %f' % (sess.run(trainer ._lr)))
- tensorboard 显示学习率的变化
其他tensorbord代码没什么不同,就不粘出来了。
tf.summary.scalar("lr",self.optimizer._lr)
- 学习率变化的结果
以上是关于tensorflow学习率控制及调试的主要内容,如果未能解决你的问题,请参考以下文章
Tensorflow+Keras学习率指数分段逆时间多项式衰减及自定义学习率衰减的完整实例
Tensorflow+Keras学习率指数分段逆时间多项式衰减及自定义学习率衰减的完整实例
TensorFlow使用记录 (三): Learning rate