Streamlit 缓存 Keras 训练模型

Posted

技术标签:

【中文标题】Streamlit 缓存 Keras 训练模型【英文标题】:Streamlit cache Keras trained model 【发布时间】:2020-07-16 17:13:46 【问题描述】:

我已经训练了一个模型(通过 Keras 框架),用 model.save('model.hdf5') 导出它,现在我想将它与很棒的 Streamlit 集成。 显然,我不想在最终用户每次插入新输入时都加载模型,而是一劳永逸地加载它。 所以我的代码看起来像这样:

@st.cache
def load_my_model():
    model = load_model('model.hdf5')
    model.summary()

    return model

if __name__ == '__main__':
    st.title('My first app')
    sentence = st.text_input('Input your sentence here:')
    model = load_my_model()
    if sentence:
        y_hat = model.predict(sentence)

这样我得到了:

"streamlit.errors.UnhashableType:"

异常。 我尝试使用@st.cache(allow_output_mutation=True),当我在streamlit 页面上运行查询时。我得到了:

“TypeError:无法将 feed_dict 键解释为 Tensor:Tensor Tensor("input_1:0", shape=(?, 80), dtype=int32) 不是此图的元素。”

(当然,没有任何缓存装饰器,模型已加载并且工作正常)

我应该如何正确加载和缓存经过 Keras 训练的模型?

Python 版本:2.7(很遗憾) Keras 版本:2.1.3 Tensorflow 版本:1.3.0 Streamlit 版本:0.55.2

非常感谢!

【问题讨论】:

【参考方案1】:

解决办法是:

    添加_make_predict_function()调用 返回会话
from keras import backend as K

@st.cache(allow_output_mutation=True)
def load_model():
    model = load_model(MODEL_PATH)
    model._make_predict_function()
    model.summary()  # included to make it visible when model is reloaded
    session = K.get_session()
    return model, session

if __name__ == '__main__':
    st.title('My first app')
    sentence = st.text_input('Input your sentence here:')
    model, session = load_model()
    if sentence:
        K.set_session(session)
        y_hat = model.predict(sentence)

【讨论】:

以上是关于Streamlit 缓存 Keras 训练模型的主要内容,如果未能解决你的问题,请参考以下文章

Macbook 上的 Streamlit

Keras:使用更大的训练集更新模型

从经过训练的 keras 模型中获取训练超参数

如何找到训练 keras 模型的 epoch 数?

keras结合了预训练模型

优化用于 Keras 模型训练的 GPU 使用