Keras - 没有停止和恢复训练的好方法?

Posted

技术标签:

【中文标题】Keras - 没有停止和恢复训练的好方法?【英文标题】:Keras - no good way to stop and resume training? 【发布时间】:2020-12-25 18:41:21 【问题描述】:

经过大量研究,似乎没有什么好的方法可以正确使用 Tensorflow 2 / Keras 模型停止和恢复训练。无论您是使用model.fit()还是使用自定义训练循环,这都是正确的。

似乎有 2 种受支持的方法可以在训练时保存模型:

    只保存模型的权重,使用model.save_weights()save_weights_only=Truetf.keras.callbacks.ModelCheckpoint。这似乎是我见过的大多数示例的首选,但是它有许多主要问题:

    优化器状态未保存,意味着训练恢复将不正确。 学习率计划已重置 - 这对某些模型来说可能是灾难性的。 Tensorboard 日志返回到第 0 步 - 除非实施复杂的变通办法,否则日志基本上无用。

    使用model.save()save_weights_only=False 保存整个模型、优化器等。优化器状态已保存(良好),但仍存在以下问题:

    Tensorboard 日志仍返回第 0 步 学习率计划仍在重置 (!!!) 无法使用自定义指标。 这在使用自定义训练循环时根本不起作用 - 自定义训练循环使用非编译模型,并且似乎不支持保存/加载非编译模型。

我发现的最佳解决方法是使用自定义训练循环,手动保存步骤。这修复了 tensorboard 日志记录,并且可以通过执行类似 keras.backend.set_value(model.optimizer.iterations, step) 的操作来修复学习率计划。但是,由于完整的模型保存不在表中,因此不会保留优化器状态。我看不出有办法独立保存优化器的状态,至少不需要做很多工作。而且像我一样搞乱 LR 的日程安排也让人觉得很麻烦。

我错过了什么吗?人们如何使用此 API 保存/恢复?

【问题讨论】:

你是对的,没有内置的 API 支持可恢复性 - 这正是促使我创建 my own 的原因。应该会在几周内发布。 我相信在使用 model.save 时可以使用自定义指标,因为 load_model 函数的 custom_objects 参数。我觉得学习率计划可以很容易地手动实现,就像你说的那样,或者甚至只是通过获取 model.fit 输出的长度,然后在下次函数时做一些数学来调整结果调用。 @Arkleseisure RE:指标 - 看起来使用 custom_objects 应该可以工作,但不幸的是它没有。 custom_objects 不支持指标 atm。 抱歉,我知道它适用于损失函数,但错误地认为它适用于指标。 【参考方案1】:

你是对的,没有内置的可恢复性支持——这正是我创建 DeepTrain 的动力。这就像 TensorFlow/Keras 的 Pytorch Lightning(在不同方面越来越好)。

为什么是另一个库?我们还不够吗?你没有这样的东西;如果有,我不会建造它。 DeepTrain 专为“保姆式”训练量身定制:训练更少的模型,但要彻底训练它们。密切监控每个阶段以诊断问题所在以及如何解决。

灵感来自我自己的使用;我会在整个漫长的时期中看到“验证峰值”,并且无法暂停,因为它会重新启动时期或以其他方式破坏火车循环。并且忘记知道您适合哪个批次,或者还剩下多少。

与 Pytorch Lightning 相比如何? 卓越的可恢复性和内省性,以及独特的火车调试实用程序 - 但 Lightning 在其他方面表现更好。我在工作中有一个全面的列表比较,将在一周内发布。

Pytorch 支持即将到来?也许吧。如果我说服 Lightning 开发团队弥补其相对于 DeepTrain 的缺点,那么不会——否则可能。同时,您可以探索Examples的图库。


小例子

from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Model
from deeptrain import TrainGenerator, DataGenerator

ipt = Input((16,))
out = Dense(10, 'softmax')(ipt)
model = Model(ipt, out)
model.compile('adam', 'categorical_crossentropy')

dg  = DataGenerator(data_path="data/train", labels_path="data/train/labels.npy")
vdg = DataGenerator(data_path="data/val",   labels_path="data/val/labels.npy")
tg  = TrainGenerator(model, dg, vdg, epochs=3, logs_dir="logs/")

tg.train()

您可以随时KeyboardInterrupt,检查模型、训练状态、数据生成器 - 并恢复。

【讨论】:

【参考方案2】:

tf.keras.callbacks.experimental.BackupAndRestoretensorflow>=2.3 添加了用于从中断中恢复训练的 API。根据我的经验,它非常有效。

参考: https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/experimental/BackupAndRestore

【讨论】:

以上是关于Keras - 没有停止和恢复训练的好方法?的主要内容,如果未能解决你的问题,请参考以下文章

Keras:恢复训练的加载检查点模型会降低准确性吗?

恢复培训 tf.keras Tensorboard

AI - TensorFlow - 示例05:保存和恢复模型

Keras - 管理历史

保存和恢复陷阱状态?管理多个陷阱处理程序的简单方法?

AVAudioPlayer 在停止或恢复时“滴答”