Keras:如何保存模型或权重?
Posted
技术标签:
【中文标题】Keras:如何保存模型或权重?【英文标题】:Keras: How to save models or weights? 【发布时间】:2019-07-22 19:59:38 【问题描述】:如果这个问题看起来很简单,我很抱歉。但是阅读 Keras 保存和恢复帮助页面:
https://www.tensorflow.org/beta/tutorials/keras/save_and_restore_models
我不明白如何在训练期间使用“ModelCheckpoint”进行保存。帮助文件提到它应该提供 3 个文件,我只看到一个,MODEL.ckpt。
这是我的代码:
checkpoint_dir = FolderName + "/tmp/model.ckpt"
cp_callback = k.callbacks.ModelCheckpoint(checkpoint_dir,verbose=1,save_weights_only=True)
parallel_model.compile(optimizer=tf.keras.optimizers.Adam(lr=learning_rate),loss=my_cost_MSE, metrics=['accuracy])
parallel _model.fit(image, annotation, epochs=epoch,
batch_size=batch_size, steps_per_epoch=10,
validation_data=(image_val,annotation_val),validation_steps=num_batch_val,callbacks=callbacks_list)
另外,当我想在训练后加载权重时:
model = k.models.load_model(file_checkpoint)
我得到错误:
"raise ValueError('Unknown ' + printable_module_name + ':' + object_name)
ValueError: Unknown loss function:my_cost_MSE"
my-cost_MSE 是我在训练中使用的成本函数。
【问题讨论】:
【参考方案1】:首先,您似乎使用的是tf.keras
(来自 tensorflow)实现而不是keras
(来自 keras-team/keras 存储库)。在这种情况下,如tf.keras guide 中所述:
保存模型的权重时,tf.keras 默认为检查点 格式。传递 save_format='h5' 以使用 HDF5。
另一方面,请注意,添加回调 ModelCheckpoint
通常大致相当于在每个 epoch 结束时调用 model.save(...)
,因此您应该期望保存三个文件(根据 @ 987654322@).
它不这样做的原因是,通过使用选项save_weights_only=True
,您只保存了权重。大致相当于在每个纪元结束时替换对model.save
的调用以替换model.save_weights
。因此,唯一要保存的文件是带有权重的文件。
从这里,您可以通过两种不同的方式进行操作:
只存储权重
您需要预先加载模型(比如说结构),然后调用 model.load_weights
而不是 keras.models.load_model
:
model = MyModel(...) # Your model definition as used in training
model.load_weights(file_checkpoint)
请注意,在这种情况下,您不会遇到自定义定义 (my_cost_MSE
) 的问题,因为您只是在加载模型权重。
存储整个模型
另一种方法是存储整个模型并相应地加载它:
cp_callback = k.callbacks.ModelCheckpoint(
checkpoint_dir,verbose=1,
save_weights_only=False
)
parallel_model.compile(
optimizer=tf.keras.optimizers.Adam(lr=learning_rate),
loss=my_cost_MSE,
metrics=['accuracy']
)
model.fit(..., callbacks=[cp_callback])
然后你可以通过以下方式加载它:
model = k.models.load_model(file_checkpoint, custom_objects="my_cost_MSE": my_cost_MSE)
请注意,在后一种情况下,您需要指定 custom_objects
,因为反序列化模型需要其定义。
【讨论】:
【参考方案2】:keras
有一个save
命令。它保存了重建模型所需的所有细节。
(来自keras docs)
from keras.models import load_model
model.save('my_model.h5') # creates a HDF5 file 'my_model.h5'
del model # deletes the existing model
# returns am identical compiled model
model = load_model('my_model.h5')
【讨论】:
如何在训练期间应用它?我在其他页面上看到这个询问保存,但在 model.fit 中它不起作用。以上是关于Keras:如何保存模型或权重?的主要内容,如果未能解决你的问题,请参考以下文章
keras 如何保存训练集与验证集正确率的差最小那次epoch的网络及权重