Tensorflow——keras model.save() raise NotImplementedError
Posted
技术标签:
【中文标题】Tensorflow——keras model.save() raise NotImplementedError【英文标题】: 【发布时间】:2019-03-04 08:22:12 【问题描述】:import tensorflow as tf
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = tf.keras.utils.normalize(x_train, axis=1)
x_test = tf.keras.utils.normalize(x_test, axis=1)
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(128,activation=tf.nn.relu))
model.add(tf.keras.layers.Dense(128,activation=tf.nn.relu))
model.add(tf.keras.layers.Dense(10,activation=tf.nn.softmax))
model.compile(optimizer ='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
model.fit(x_train, y_train, epochs=3)
当我试图保存模型时
model.save('epic_num_reader.model')
我得到一个 NotImplementedError:
NotImplementedError Traceback (most recent call last)
<ipython-input-4-99efa4bdc06e> in <module>()
1
----> 2 model.save('epic_num_reader.model')
NotImplementedError: Currently `save` requires model to be a graph network. Consider using `save_weights`, in order to save the weights of the model.
那么如何保存代码中定义的模型呢?
【问题讨论】:
啊,我看到你也在尝试pythonprogramming.net/… 的教程。很高兴看到您已经找到了解决方案 【参考方案1】:你忘记了第一层定义中的input_shape
参数,导致模型未定义,保存未定义模型尚未实现,触发错误。
model.add(tf.keras.layers.Flatten(input_shape = (my, input, shape)))
只需将input_shape
添加到第一层,它应该可以正常工作。
【讨论】:
这是一个很好的答案。谢谢。 我遵循相同的 senddex 教程并将代码更改为model.add(tf.keras.layers.Flatten(input_shape=x_train.shape))
,但运行 model.fit()
时得到了 expected flatten_1_input to have 4 dimensions, but got array with shape (60000, 28, 28)
。代码到底应该是什么?
@hanaZ 输入形状不应该包含batch/samples维度,以后记住cmet是为了澄清问题和答案,而不是介绍你自己的。
@Matias。非常感谢您的回复。我试过 model.add(tf.keras.layers.Flatten(input_shape=x_train[0].shape))
,然后它工作了。
它可以正常工作,但我发现我的模型的性能大幅下降并且asked a question about it(目前无法解释)【参考方案2】:
如果按照 Matias 的建议还是没有解决问题,可以考虑使用tf.keras.models.save_model()
和load_model()
。就我而言,它奏效了。
【讨论】:
【参考方案3】:tf.keras.models.save_model
在此处工作(tensorflow 1.12.0)(即使未指定 input_shape)
【讨论】:
【参考方案4】:错误原因:
我遇到了同样的错误并尝试了上述答案,但得到了错误。但我找到了解决问题的方法,我将在下面分享:
在定义模型的输入层时检查是否传递了input_shape,如果没有,在保存和加载模型时会报错。
如何定义input_shape?
让我们考虑一个例子,如果你使用 minst 数据集:
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
它由手写数字 0-9 的图像组成,每个图像的分辨率为 28 x 28。 为此,我们可以将输入形状定义为 (28,28),而无需提及批量大小,如下所示:
model.add(tf.keras.layers.Flatten(input_shape=(28,28)))
通过这种方式,您可以通过查看输入训练数据集来给出输入形状。
保存训练好的模型:
现在在训练和测试模型之后,我们可以保存我们的模型。以下代码对我有用,它在重新加载模型后也没有改变准确性:
使用 save_model()
import tensorflow as tf
tf.keras.models.save_model(
model,
"your_trained_model.model",
overwrite=True,
include_optimizer=True
)
使用 .save()
your_trained_model.save('your_trained_model.model')
del model # deletes the existing model
现在加载我们保存的模型:
model2 = tf.keras.models.load_model("your_trained_model.model")
更多详情请参考此链接:Keras input explanation: input_shape, units, batch_size, dim, etc
【讨论】:
【参考方案5】:<!-- Success, please check -->
import tensorflow as tf
import matplotlib.pyplot as plt
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
plt.imshow(x_train[0], cmap=plt.cm.binary)
x_train = tf.keras.utils.normalize(x_train, axis=1)
x_test = tf.keras.utils.normalize(x_test, axis=1)
plt.imshow(x_train[0], cmap=plt.cm.binary)
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Flatten(input_shape=x_train[0].shape))
model.add(tf.keras.layers.Dense(128, activation=tf.nn.relu))
model.add(tf.keras.layers.Dense(128, activation=tf.nn.relu))
model.add(tf.keras.layers.Dense(10, activation=tf.nn.softmax))
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.fit(x_train, y_train, epochs=3)
val_loss, val_acc = model.evaluate(x_test, y_test)
print(val_loss)
print(val_acc)
model.save('epic_num_reader.model')
【讨论】:
以上是关于Tensorflow——keras model.save() raise NotImplementedError的主要内容,如果未能解决你的问题,请参考以下文章
Tensorflow+Keras用Tensorflow.keras的方法替代keras.layers.merge
keras与tensorflow.python.keras - 使用哪一个?
keras 与 tensorflow.python.keras - 使用哪一个?
无法在 Keras 2.1.0(使用 Tensorflow 1.3.0)中保存的 Keras 2.4.3(使用 Tensorflow 2.3.0)中加载 Keras 模型