keras的模型保存与加载
Posted 月疯
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了keras的模型保存与加载相关的知识,希望对你有一定的参考价值。
三种方式保存模型:
save/load weights
最简单的保存,只保存模型的参数save/load entire model
会保存模型的所有参数saved_model
用户最后得到模型就可以直接部署,不需要直接把代码给了用户,python训练的模型,可以个了c++完成一个工厂的部署
第一种简单方法:
#保存模型
model.save_weights('./checkpoints/my_checkpoint')
#加载模型
model =create_model()
model.load_weights('./checkpoints/my_checkpoint')
#评估模型
loss,acc = model.evaluate(test_images,test_labels)
#查看保存之前的acc和保存之后的基本没啥差别
第二种还原整个模型方法:
network.save('model.h5')
print('saved total model.')
del network
print('load model from file')
network = tf.keras.models.load_model('model.h5')
network.evaluate(x_val,y_val)
第三种模型保存,工业环境部署
#保存模型
tf.saved_model.save(m,'tmp/saved_model/')
#导出模型
imported = tf.saved_model.load(path)
f = imported.signatures['serving_default']
完整列子:
import numpy as np
from keras.datasets import mnist #直接从keras里面应用数据集
from keras.utils import np_utils #keras 里面用到的一个 np 的工具包
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import SGD #优化函数;
from keras.models import load_model
#载入数据
(x_train,y_train),(x_test,y_test)=mnist.load_data() #分为测试集和训练集
# (6000,28,28) -> (6000,784)
x_train=x_train.reshape(x_train.shape[0],-1)/255.0 #-1表示是自动判断,/225是表示归一化。
x_test=x_test.reshape(x_test.shape[0],784)/255.0#行数是 x_train.shape[0]行。
#标签转换成 one hot 格式
y_train=np_utils.to_categorical(y_train,num_classes=10)#专门用来转格式的包
y_test=np_utils.to_categorical(y_test,num_classes=10)
#载入模型
model=load_model("model.h5")
#载入的模型是已经训练好的,直接进行评估就行。
loss,accuracy=model.evaluate(x_test,y_test)
print(loss)
print(accuracy)
#把模型转成json格式,打印出来看一下;
json_string=model.to_json()
print(json_string)
#对于载入的模型,可以继续进行训练;
model.fit(x_train,y_train,batch_size=64,epochs=2)
loss,accuracy=model.evaluate(x_test,y_test)
print(loss)
print(accuracy)
summary()计算模型参数
from keras.models import Sequential
from keras.layers.core import Dense, Dropout, Activation
model = Sequential() # 顺序模型
# 输入层
model.add(Dense(7, input_shape=(4,))) # Dense就是常用的全连接层
model.add(Activation('sigmoid')) # 激活函数
# 隐层
model.add(Dense(13)) # Dense就是常用的全连接层
model.add(Activation('sigmoid')) # 激活函数
# 输出层
model.add(Dense(5))
model.add(Activation('softmax'))
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=["accuracy"])
model.summary()
输出:
以上是关于keras的模型保存与加载的主要内容,如果未能解决你的问题,请参考以下文章
从不同版本的 tf.keras 加载保存的模型(从 tf 2.3.0 到 tf 1.12)