tf.keras 如何保存 ModelCheckPoint 对象

Posted

技术标签:

【中文标题】tf.keras 如何保存 ModelCheckPoint 对象【英文标题】:tf.keras how to save ModelCheckPoint object 【发布时间】:2020-04-02 22:19:18 【问题描述】:

ModelCheckpoint 可用于根据特定的监控指标保存最佳模型。所以它显然有关于存储在其对象中的最佳指标的信息。例如,如果您在 google colab 上进行培训,您的实例可能会在没有警告的情况下被终止,并且在长时间的培训课程后您会丢失此信息。

我试图腌制 ModelCheckpoint 对象但得到:

TypeError: can't pickle _thread.lock objects  

这样当我把笔记本带回来时,我可以重复使用这个相同的对象。有没有好的方法来做到这一点?您可以尝试通过以下方式重现:

chkpt_cb = tf.keras.callbacks.ModelCheckpoint('model.epoch:02d-val_loss:.4f.h5',
                                              monitor='val_loss',
                                              verbose=1,
                                              save_best_only=True)

with open('chkpt_cb.pickle', 'w') as f:
  pickle.dump(chkpt_cb, f, protocol=pickle.HIGHEST_PROTOCOL)

【问题讨论】:

你能发布你正在使用的当前代码块吗? ModelCheckpoint 通常是一个回调,因此从您的描述中不清楚您是如何使用它的。 @adamconkey 我已经用代码更新了它以重现。这相当简单。我只想腌制回调对象。根据错误,它一定与线程问题有关。 我找到的快速答案:Pickle chkpt_cb.best,然后将其重新分配给新的检查点。刚刚试了一下,效果很好。 【参考方案1】:

如果回调对象不被腌制(由于线程问题而不可取),我可以改为腌制:

best = chkpt_cb.best

这存储了回调见过的最好的监控指标,它是一个浮点数,你可以pickle并下次重新加载,然后这样做:

chkpt_cb.best = best   # if chkpt_cb is a brand new object you create when colab killed your session. 

这是我自己的设置:

# All paths should be on Google Drive, I omitted it here for simplicity.

chkpt_cb = tf.keras.callbacks.ModelCheckpoint(filepath='model.epoch:02d-val_loss:.4f.h5',
                                              monitor='val_loss',
                                              verbose=1,
                                              save_best_only=True)

if os.path.exists('chkpt_cb.best.pickle'):
  with open('chkpt_cb.best.pickle', 'rb') as f:
    best = pickle.load(f)
    chkpt_cb.best = best

def save_chkpt_cb():
  with open('chkpt_cb.best.pickle', 'wb') as f:
    pickle.dump(chkpt_cb.best, f, protocol=pickle.HIGHEST_PROTOCOL)

save_chkpt_cb_callback = tf.keras.callbacks.LambdaCallback(
    on_epoch_end=lambda epoch, logs: save_chkpt_cb()
)

history = model.fit_generator(generator=train_data_gen,
                          validation_data=dev_data_gen,
                          epochs=5,
                          callbacks=[chkpt_cb, save_chkpt_cb_callback])

因此,即使您的 colab 会话被终止,您仍然可以检索最近的最佳指标并将其告知您的新实例,并照常继续训练。当您重新编译有状态优化器并且可能导致损失/度量回归并且不想在前几个时期保存这些模型时,这尤其有用。

【讨论】:

【参考方案2】:

我认为您可能误解了 ModelCheckpoint 对象的预期用途。这是一个callback,在特定阶段的训练期间定期被调用。特别是 ModelCheckpoint 回调在每个 epoch 之后被调用(如果您保留默认的 period=1)并将您的模型以您指定的文件名保存到磁盘到 filepath 参数。模型的保存方式与here 中描述的相同。然后,如果您想稍后加载该模型,您可以执行类似的操作

from keras.models import load_model
model = load_model('my_model.h5')

关于 SO 的其他答案为从保存的模型继续训练提供了很好的指导和示例,例如:Loading a trained Keras model and continue training。重要的是,保存的 H5 文件存储了继续训练所需的模型的所有内容。

正如Keras documentation 中所建议的,您不应该使用pickle 来序列化您的模型。只需使用“fit”函数注册 ModelCheckpoint 回调:

chkpt_cb = tf.keras.callbacks.ModelCheckpoint('model.epoch:02d-val_loss:.4f.h5',
                                              monitor='val_loss',
                                              verbose=1,
                                              save_best_only=True)
model.fit(x_train, y_train,
          epochs=100,
          steps_per_epoch=5000,
          callbacks=[chkpt_cb])

您的模型将保存在一个 H5 文件中,以您拥有的名称命名,并自动为您格式化 epoch 数和损失值。例如,您保存的第 5 个 epoch 的损失为 0.0023 的文件看起来像 model.05-.0023.h5,并且由于您设置了 save_best_only=True,因此只有在您的损失比之前保存的更好时才会保存模型,这样您就不会污染你的目录有一堆不需要的模型文件。

【讨论】:

是的,我明白应该这样使用它。如果您使用过 colab 并在训练过程中被截断,您会发现如果您从头开始重新实例化回调,您最后的最佳指标将被遗忘。所以我试图找到回调对象可以保留在磁盘上的解决方案。如果您的笔记本会话是实时的,它肯定会在内存中。您可以运行多个 fit(...) 并且它仍然跟踪迄今为止最好的指标。 我找到了答案并发布了。最好的指标肯定存储在回调对象中。

以上是关于tf.keras 如何保存 ModelCheckPoint 对象的主要内容,如果未能解决你的问题,请参考以下文章

恢复培训 tf.keras Tensorboard

tf2.0 Keras:使用 RNN 的自定义张量流代码时无法保存权重

如何在 Tensorflow 中从 tf.keras 导入 keras?

如何将 tf.keras 与 bfloat16 一起使用

JSONDecodeError:加载 tf.Keras 模型时的期望值

如何使用 tf.keras 在 RNN 中应用层规范化?