Keras 模型在本地运行良好,但无法在 Flask API 上运行

Posted

技术标签:

【中文标题】Keras 模型在本地运行良好,但无法在 Flask API 上运行【英文标题】:Keras model working fine locally but won't work on Flask API 【发布时间】:2019-05-21 07:06:34 【问题描述】:

我正在使用不同的分类器解决这个心脏病检测问题。 我正在做的是将模型保存在 h5 文件中并创建它的对象并以 json 格式返回响应。

但相同的模型不适用于在我的终端上完美运行的烧瓶 api。

这是我的神经网络

def ANN():
    global x_train,x_test,y_train,y_test
    model = Sequential()

    #implicit input layer combined with hidden layer
    model.add(Dense(units = 13, kernel_initializer = 'uniform', activation = 'relu', input_dim = 13))

    #hidden layer 2
    model.add(Dense(units = 13, kernel_initializer = 'uniform', activation = 'relu', input_dim = 13))

    #output layer
    model.add(Dense(units = 1, kernel_initializer = 'uniform', activation = 'sigmoid'))

    model.compile(optimizer = 'adam', loss = 'binary_crossentropy', metrics = ['accuracy'])

    #fitting with optimal hyperparameters
    model.fit(x_train, y_train, batch_size = 25, nb_epoch = 287)

    return 'model':model,
            'accuracy':accuracy_score(model.predict(x_test) > 0.5,y_test)*100

.h5文件中保存模型后,在我的flask api中,

ann = load_model('ann8524.h5')
print(ann.predict(x_test)) #test set, for just checking.

这是错误信息:

* Running on http://127.0.0.1:5000/ (Press CTRL+C to quit) [2018-12-20 23:37:43,548] 
ERROR in app: Exception on /heart/predict

[GET] Traceback (most recent call last):   
File "C:\python_installation\lib\site-packages\flask\app.py", line 1813, in full_dispatch_request
    rv = self.dispatch_request()   
File "C:\python_installation\lib\site-packages\flask\app.py", line 1799, in dispatch_request
    return self.view_functions[rule.endpoint](**req.view_args)   
File "C:\python_installation\lib\site-packages\flask_restful\__init__.py", line 458, in wrapper
    resp = resource(*args, **kwargs)   
File "C:\python_installation\lib\site-packages\flask\views.py", line 88, in view
    return self.dispatch_request(*args, **kwargs)   
File "C:\python_installation\lib\site-packages\flask_restful\__init__.py", line 573, in dispatch_request
    resp = meth(*args, **kwargs)   
File "app.py", line 41, in get
    print(ann.predict(x_test))   
File "C:\python_installation\lib\site-packages\keras\engine\training.py", line 1164, in predict
    self._make_predict_function()   
File "C:\python_installation\lib\site-packages\keras\engine\training.py", line 554, in _make_predict_function
    **kwargs)   
File "C:\python_installation\lib\site-packages\keras\backend\tensorflow_backend.py", line 2744, in function
    return Function(inputs, outputs, updates=updates, **kwargs)   
File "C:\python_installation\lib\site-packages\keras\backend\tensorflow_backend.py", line 2546, in __init__
    with tf.control_dependencies(self.outputs):   
File "C:\python_installation\lib\site-packages\tensorflow\python\framework\ops.py", line 5004, in control_dependencies
    return get_default_graph().control_dependencies(control_inputs)   
File "C:\python_installation\lib\site-packages\tensorflow\python\framework\ops.py", line 4543, in control_dependencies
    c = self.as_graph_element(c)   
File "C:\python_installation\lib\site-packages\tensorflow\python\framework\ops.py", line 3490, in as_graph_element
    return self._as_graph_element_locked(obj, allow_tensor, allow_operation)   
File "C:\python_installation\lib\site-packages\tensorflow\python\framework\ops.py", line 3569, in _as_graph_element_locked
    raise ValueError("Tensor %s is not an element of this graph." % obj) 

ValueError: Tensor Tensor("dense_3/Sigmoid:0", shape=(?, 1), dtype=float32) is not an element of this graph.

127.0.0.1 - - [20/Dec/2018 23:37:43] "[1m[35mGET /heart/predict HTTP/1.1[0m" 500 -

但它在 Spyder 中运行良好。 (完全相同的代码)

【问题讨论】:

您可以在加载模型后尝试打印模型摘要而不进行预测。只是试图缩小与预测或加载已保存模型有关的问题。 【参考方案1】:

您需要从 Tensorflow 获取默认图表,按照以下步骤应该可以解决此问题:

import tensorflow as tf
ann = load_model('ann8524.h5')
graph = tf.get_default_graph()

def your_handler():
    global graph
    with graph.as_default():
        print(ann.predict(x_test))

【讨论】:

有人能解释一下为什么这是必要的吗?

以上是关于Keras 模型在本地运行良好,但无法在 Flask API 上运行的主要内容,如果未能解决你的问题,请参考以下文章

我在 google colab 上训练了一个 keras 模型。现在无法在我的系统上本地加载它。

Azure ML Studio 环境中的 Python 自定义模型错误 0085,在本地环境中运行良好

无法使用 Plaidml 在 GPU 上运行 Keras 模型

视觉工作室正在运行,但本地主机无法连接

用于删除或存档旧日志文件的 PowerShell 脚本在本地运行良好,但在远程位置无法运行

带有张量流的 keras 运行良好,直到我添加回调