如何将 keras 模型与其他数据一起保存并完全加载?
Posted
技术标签:
【中文标题】如何将 keras 模型与其他数据一起保存并完全加载?【英文标题】:How to save keras model along with other data and load them altogether? 【发布时间】:2018-08-08 21:13:07 【问题描述】:我有以下代码来训练一个 keras 神经网络
from keras import Sequential
from keras.layers import Dense
from keras.models import load_model
import numpy as np
class Model:
def __init__(self, data=None):
self.data = data
self.metrics = []
self.model = self.__build_model()
def __build_model(self):
model = Sequential()
model.add(Dense(4, activation='relu', input_shape=(3,)))
model.add(Dense(1, activation='relu'))
model.compile(loss='mean_squared_error', optimizer='adam', metrics=['accuracy'])
return model
def train(self, epochs):
self.model.fit(self.data[:, :-1], self.data[:,-1], epochs=epochs)
return self
def test(self, data):
self.metrics = self.model.evaluate(data[:, :-1], data[:, -1])
return self
def predict(self, input):
return self.model.predict(input)
def save(self, path):
self.model.save(path)
# I would like to save self.metrics at the same time
def load(self, path):
self.model = load_model(path)
if __name__ == '__main__':
train_data = np.random.rand(1000, 4)
test_data = np.random.rand(100, 4)
print("TRAINING, TESTING & SAVING..")
model = Model(train_data)\
.train(epochs=5)\
.test(test_data)\
.save('./model.h5')
print('LOADING model & PREDICTING..')
test_sample = np.random.rand(1, 3)
model = Model()
model.load('./model.h5')
# I can then do like:
test_output = model.predict(test_sample)
print(test_output)
# And want to get metrics which i had saved with it like:
metrics = model.metrics
print(metrics)
如您所见,它将模型保存到 h5 文件中,但只有 keras 模型没有其他任何内容。 如何像指标一样同时保存其他数据,然后在加载 keras 模型时也能够加载它们。
谢谢!
【问题讨论】:
【参考方案1】:您可以使用任何序列化框架来执行此操作。
import hickle
def save(self, path):
self.model.save(path)
hkl.dump(self.metrics, 'metrics.hkl', mode='w')
def load(self, path):
self.model = load_model(path)
self.metrics = hkl.load('metrics.hkl')
您也可以将其保存为单个文件,只需从指标和模型对象中创建一个列表或另一个对象。我建议将它们分开保存。
【讨论】:
以上是关于如何将 keras 模型与其他数据一起保存并完全加载?的主要内容,如果未能解决你的问题,请参考以下文章
如何将保存的模型转换或加载到 TensorFlow 或 Keras?
keras 如何保存训练集与验证集正确率的差最小那次epoch的网络及权重
如何将 GeoFire 坐标与 Firebase 数据库中的其他项目一起保存?