无法在 TensorFlow 2 中加载模型权重
Posted
技术标签:
【中文标题】无法在 TensorFlow 2 中加载模型权重【英文标题】:Cannot load model weights in TensorFlow 2 【发布时间】:2020-06-16 15:04:38 【问题描述】:在 TensorFlow 2.2 中保存模型权重后,我无法加载它们。权重似乎保存正确(我认为),但是我无法加载预训练模型。
我当前的代码是:
segmentor = sequential_model_1()
discriminator = sequential_model_2()
def save_model(ckp_dir):
# create directory, if it does not exist:
utils.safe_mkdir(ckp_dir)
# save weights
segmentor.save_weights(os.path.join(ckp_dir, 'checkpoint-segmentor'))
discriminator.save_weights(os.path.join(ckp_dir, 'checkpoint-discriminator'))
def load_pretrained_model(ckp_dir):
try:
segmentor.load_weights(os.path.join(ckp_dir, 'checkpoint-segmentor'), skip_mismatch=True)
discriminator.load_weights(os.path.join(ckp_dir, 'checkpoint-discriminator'), skip_mismatch=True)
print('Loading pre-trained model from: 0'.format(ckp_dir))
except ValueError:
print('No pre-trained model available.')
然后我有训练循环:
# training loop:
for epoch in range(num_epochs):
for image, label in dataset:
train_step()
# save best model I find during training:
if this_is_the_best_model_on_validation_set():
save_model(ckp_dir='logs_dir')
然后,在“for 循环”训练结束时,我想加载最佳模型并对其进行测试。因此,我运行:
# load saved model and do a test:
load_pretrained_model(ckp_dir='logs_dir')
test()
但是,这会导致ValueError
。我检查了应该保存权重的目录,它们就在那里!
知道我的代码有什么问题吗?我是否错误地加载了重量?
谢谢!
【问题讨论】:
你能发布你得到的完整错误吗?this_is_the_best_model_on_validation_set
曾经评估过真实吗?文件真的存在吗?
@craymichael 感谢您的帮助 :) 是的。我不知道为什么我无法加载它们
【参考方案1】:
好的,这是您的问题 - 您拥有的 try-except
块掩盖了真正的问题。删除它会得到ValueError
:
ValueError: When calling model.load_weights, skip_mismatch can only be set to True when by_name is True.
有两种方法可以缓解这种情况 - 您可以使用by_name=True
调用load_weights
,或者根据您的需要删除skip_mismatch=True
。在测试您的代码时,这两种情况都适用于我。
另一个考虑因素是,当您将鉴别器和分段器检查点都存储到日志目录时,您每次都会覆盖 checkpoint
文件。这包含两个字符串,它们给出了特定模型检查点文件的路径。由于您第二次保存鉴别器,因此每次该文件都会说鉴别器而不参考分段器。您可以通过将每个模型存储在日志目录的两个子目录中来缓解这种情况,即
logs_dir/
+ discriminator/
+ checkpoint
+ ...
+ segmentor/
+ checkpoint
+ ...
尽管在当前状态下,您的代码可以在这种情况下工作。
【讨论】:
太棒了!这解决了我的问题。但是,现在我打印出以下工作:WARNING:tensorflow:A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/guide/checkpoint#loading_mechanics for details.
你知道它是关于什么的吗?
呵呵,我以前没见过。看起来在 TF Keras 加载/保存教程中他们也收到了该消息,但他们没有解决它。还有一个尚未解决的问题here。恢复会改变结果吗?否则我认为你可以忽略它
我认为它不会影响我。不过,我觉得很奇怪,他们没有在官方教程中讨论它。谢谢你的帮助! :)以上是关于无法在 TensorFlow 2 中加载模型权重的主要内容,如果未能解决你的问题,请参考以下文章
无法在 Keras 2.1.0(使用 Tensorflow 1.3.0)中保存的 Keras 2.4.3(使用 Tensorflow 2.3.0)中加载 Keras 模型
无法在 tensorflow 官方 resnet 模型中加载用于 eval 的图像
在 Django 应用程序中加载 TensorFlow 模型的位置