Streamlit 缓存 Keras 训练模型
Posted
技术标签:
【中文标题】Streamlit 缓存 Keras 训练模型【英文标题】:Streamlit cache Keras trained model 【发布时间】:2020-07-16 17:13:46 【问题描述】:我已经训练了一个模型(通过 Keras 框架),用 model.save('model.hdf5')
导出它,现在我想将它与很棒的 Streamlit 集成。
显然,我不想在最终用户每次插入新输入时都加载模型,而是一劳永逸地加载它。
所以我的代码看起来像这样:
@st.cache
def load_my_model():
model = load_model('model.hdf5')
model.summary()
return model
if __name__ == '__main__':
st.title('My first app')
sentence = st.text_input('Input your sentence here:')
model = load_my_model()
if sentence:
y_hat = model.predict(sentence)
这样我得到了:
"streamlit.errors.UnhashableType:"
异常。
我尝试使用@st.cache(allow_output_mutation=True)
,当我在streamlit 页面上运行查询时。我得到了:
“TypeError:无法将 feed_dict 键解释为 Tensor:Tensor Tensor("input_1:0", shape=(?, 80), dtype=int32) 不是此图的元素。”
(当然,没有任何缓存装饰器,模型已加载并且工作正常)
我应该如何正确加载和缓存经过 Keras 训练的模型?
Python 版本:2.7(很遗憾) Keras 版本:2.1.3 Tensorflow 版本:1.3.0 Streamlit 版本:0.55.2非常感谢!
【问题讨论】:
【参考方案1】:解决办法是:
-
添加
_make_predict_function()
调用
返回会话
from keras import backend as K
@st.cache(allow_output_mutation=True)
def load_model():
model = load_model(MODEL_PATH)
model._make_predict_function()
model.summary() # included to make it visible when model is reloaded
session = K.get_session()
return model, session
if __name__ == '__main__':
st.title('My first app')
sentence = st.text_input('Input your sentence here:')
model, session = load_model()
if sentence:
K.set_session(session)
y_hat = model.predict(sentence)
【讨论】:
以上是关于Streamlit 缓存 Keras 训练模型的主要内容,如果未能解决你的问题,请参考以下文章