从keras中SGD源码理解基于时间的学习速率衰减decay策略

Posted ybdesire

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了从keras中SGD源码理解基于时间的学习速率衰减decay策略相关的知识,希望对你有一定的参考价值。

1. 引入

在求解神经网络优化问题中,学习速率不应该是固定不变的。最好的学习速率,应该在训练开始时偏大,这样能加快训练速度,并在训练过程中逐步减小,这样能更好的逼近最优点。

所以,在参考1中,我们知道了有SGD, Adam, Adadelta 等这些非常经典的优化算法。

2. decay

从参考2的中,我们能看到如何调节SGD, Adam等优化器的参数。
同时我们也发现了一个参数叫做decay,它表示学习速率的衰减值。
decay可以用于SGD, Adam, RMSprop, Adagrad, Adadelta, Adamax, Nadam等keras提供的优化器
decay的默认值是0,可以设置为任何大于等于零的浮点数。

那么,它的含义是什么呢?

从参考3,参考4中,我们可以看到公式表示decay对learning rate的影响,从这个公式表达上,可以看出如下规律

  • 每一次epoch/iteration训练后,如果decay不为零,都会在不同优化器计算得到的learning rate的基础上,进一步减小(衰减)learning rate
  • decay越大,每一次epoch/iteration训练后,learning rate下降(衰减)的更快
  • decay为0,则每一次epoch/iteration训练后得到的learning rate不会改变

3. SGD中decay计算过程源码分析

从上一节中,我们发现有的资料给出了decay对learning rate影响的公式。本文也通过查找keras源码,在github上搜索关键字“class SGD”,找出了SGD中decay对learning rate计算的核心代码(如下代码是从源码中提取出来并经过精简、注释):

https://github.com/keras-team/keras/blob/master/keras/optimizer_v1.py#L188

class SGD(Optimizer):
  def get_updates(self, loss, params):
    lr = self.lr # 获取上一次计算得到的learning rate
    if self.initial_decay > 0: # 如果decay设置不为0
      lr = lr * (     1. /
          (1. + self.decay * tf.cast(self.iterations, backend.dtype(self.decay))))
        # tf.cast是做类型转换,相当于把self.iterations值转换为和self.decay数据类型相同

所以,learning rate的更新策略,就是

lr = lr * (1. /(1. + decay * iterations))

也就是参考3和参考4中给出的公式。

4. 总结

  1. keras中Adam,SGD等优化器,有一个参数叫做decay
  2. decay是调节“基于时间的学习速率衰减”策略
  3. 调节decay值,只是最简单的一种神经网络学习速率衰减策略
  4. decay值默认为0,表示learning rate不随时间(epoch/iterations)衰减
  5. decay值越大,则每一次epoch/iteration训练后得到的learning rate衰减越快

. 参考

  1. https://blog.csdn.net/ybdesire/article/details/51792925
  2. https://keras.io/zh/optimizers/
  3. https://stats.stackexchange.com/questions/211334/keras-how-does-sgd-learning-rate-decay-work
  4. https://zhuanlan.zhihu.com/p/78096138

以上是关于从keras中SGD源码理解基于时间的学习速率衰减decay策略的主要内容,如果未能解决你的问题,请参考以下文章

Tensorflow+Keras学习率指数分段逆时间多项式衰减及自定义学习率衰减的完整实例

Keras 中 Adam 优化器的衰减参数

如何通过遵循全局步骤在Keras中实现指数衰减学习率

参数更新

图像去雾基于matlab颜色衰减先验图像去雾含Matlab源码 2036期

图像去雾基于matlab颜色衰减先验图像去雾含Matlab源码 2036期