Keras model.save() 和 model.save_weights() 的区别?

Posted

技术标签:

【中文标题】Keras model.save() 和 model.save_weights() 的区别?【英文标题】:Difference between Keras model.save() and model.save_weights()? 【发布时间】:2017-07-26 02:25:14 【问题描述】:

在 Keras 中保存模型,输出文件有什么区别:

    model.save() model.save_weights() 回调中的ModelCheckpoint()

model.save() 中保存的文件比model.save_weights() 中的模型大,但比 JSON 或 Yaml 模型架构文件大得多。为什么是这样?

重申一下:为什么 size(model.save()) + size(something) = size(model.save_weights()) + size(model.to_json()),那是什么“东西”?

只使用model.save_weights()model.to_json() 并从中加载会比只使用model.save()load_model() 更有效吗?

有什么区别?

【问题讨论】:

【参考方案1】:

save() 将权重和模型结构保存到单个HDF5 文件中。我相信它还包括优化器状态之类的东西。然后,您可以使用带有load() 的 HDF5 文件来重建整个模型,包括权重。

save_weights() 仅将权重保存到 HDF5,仅此而已。您需要额外的代码来从 JSON 文件重建模型。

【讨论】:

只是想澄清一下,使用h5dump --contents我将保存的模型与保存的权重进行了比较,我可以看到权重只是模型 hd5 文件中的一个“组”。然而,也有优化器状态。但我没有看到任何与模型架构相关的文件。优化器状态是什么?模型架构如何持久化? @CMCDragonkai 这不是一个澄清,而是一个新问题。 @MatiasValdenegro 您愿意解释一下为什么要保存优化器的状态吗?在尝试继续训练相同的模型但在不同的会话中(例如,关闭 python 并在另一天继续训练)时,仅加载权重可能会出现什么问题。 @payne 优化器有状态,比如梯度的运行方式,所以如果从头开始,学习可能会不稳定甚至失败。 @MatiasValdenegro 所以使用model.save_weights('my_model_weights.h5') 会打乱学习过程,应该使用model.save('my_model.h5') 以便在中断的地方继续训练?【参考方案2】: model.save_weights() 只会保存权重,因此如果需要,您可以将它们应用到不同的架构上 mode.save() 将保存模型的架构 + 权重 + 训练配置 + 优化器的状态

【讨论】:

doc 此函数与 Keras Model 的 save_weights 函数略有不同。 from tf.keras.Model.save_weights 使用 filepath 中指定的名称创建一个检查点文件,而 tf.train.Checkpoint 为检查点编号,使用 filepath 作为检查点文件名的前缀。除此之外,model.save_weights() 和 tf.train.Checkpoint(model).save() 是等价的。 但你的答案不同。【参考方案3】:

只需添加 ModelCheckPoint 的输出,如果它与其他任何人相关:在模型训练期间用作回调,它可以保存整个模型或仅保存权重,具体取决于 save_weights_only 参数设置的状态。 TRUE 和权重只被保存,类似于调用model.save_weights()。 FALSE(默认)并保存整个模型,如调用model.save()

【讨论】:

【参考方案4】:

除了上述答案之外,从 tf.keras 版本“2.7.0”开始,可以使用 model.save() 以 2 种格式保存模型,即 TensorFlow SavedModel 格式和较旧的 Keras H5 格式。推荐的格式是 SavedModel,它是调用 model.save() 时的默认格式。要保存为 .h5(HDF5) 格式,请使用 model.save('my_model', save_format='h5') More

【讨论】:

以上是关于Keras model.save() 和 model.save_weights() 的区别?的主要内容,如果未能解决你的问题,请参考以下文章

Keras模型保存的几个方法和它们的区别

保存及读取keras模型参数

Tensorflow——keras model.save() raise NotImplementedError

Keras模型导入报错

将模型保存在keras中是不是有先决条件?

Keras - 没有停止和恢复训练的好方法?