tf2 模型保存总结
Posted xiexiaokui
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了tf2 模型保存总结相关的知识,希望对你有一定的参考价值。
tf2 模型保存总结
1. model.save保存的是所有信息,结果是单文件,最为简单。
实例:保 加
model_name = "./model_save/fassionMnist_save.h5"
model.save(model_name)
new_model = keras.models.load_model(model_name)
2. model.save_weights(weight_file)保存的是权重,结果是单文件。
weight_file="./model_save/weights.h5"
示例:保 创 编 加
model.save_weights(weight_file)
?
model = keras.Sequential()
model.add(keras.layers.Flatten(input_shape=(28,28)))
model.add(keras.layers.Dense(128,activation="relu"))
model.add(keras.layers.Dense(10, activation="softmax"))
model.summary()
?
model.compile(optimizer="adam",
loss="sparse_categorical_crossentropy",
metrics=["acc"])
?
model.load_weights(weight_file)
3. 检查点保存权重,结果多文件
示例:
ckpt_path="./ckpt/model_ckpt.ckpt"
ckpt_callback=keras.callbacks.ModelCheckpoint(
ckpt_path,save_weights_only=True)
history = model.fit(train_image,train_label,epochs=3,callbacks=[ckpt_callback])
?
model = keras.Sequential()
model.add(keras.layers.Flatten(input_shape=(28,28)))
model.add(keras.layers.Dense(128,activation="relu"))
model.add(keras.layers.Dense(10, activation="softmax"))
model.summary()
?
model.compile(optimizer="adam",
loss="sparse_categorical_crossentropy",
metrics=["acc"])
?
model.load_weights(ckpt_path)
?
4. 检查点保存全部模型,结果是文件夹
而且win下保存路径必须用 反斜杠,不能用正斜杠,可视为bug
model_ckpt_path=".ckptmodel3.model"
ckpt_callback=keras.callbacks.ModelCheckpoint(
model_ckpt_path,save_weights_only=False)
model.evaluate(test_image,test_label,verbose=0)
history = model.fit(train_image,train_label,epochs=3,callbacks=[ckpt_callback])
model.evaluate(test_image,test_label,verbose=0)
?
new_model = keras.models.load_model(model_ckpt_path)
new_model.evaluate(test_image,test_label,verbose=0)
以上是关于tf2 模型保存总结的主要内容,如果未能解决你的问题,请参考以下文章
TF2.0:翻译模型:恢复保存的模型时出错:检查点(根)中未解析的对象.optimizer.iter:属性