在不同的会话中加载 TensorFlow 模型

Posted

技术标签:

【中文标题】在不同的会话中加载 TensorFlow 模型【英文标题】:Loading Tensorflow model in different session 【发布时间】:2018-03-31 11:49:53 【问题描述】:

我对这一切有点陌生,所以你能帮帮我吗?我试图找到这个问题的答案,但一无所获。

我正在尝试在一个单独的函数中在 python 中加载 Tensorflow 模型,这样我就可以在循环中使用该模型,而不必在 for 循环的每次迭代中加载它。

这是我现在的代码:

def load_network():
    prediction = neural_network_model(x)    
    return (prediction)



def use_neural_network(data, prediction):         
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.import_meta_graph(model_name+'.meta')
        saver.restore(sess,model_name)
        pred = sess.run(prediction, feed_dict=x: data)
        pred = np.asarray(pred)
    return pred


if __name__ == '__main__':
    result=[]
    Load= start_network()
    for i in data:
        result.append(use_neural_network(i,Load))

我想得到这样的东西:

def load_network():
    prediction = neural_network_model(x)    
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.import_meta_graph(model_name+'.meta')
        saver.restore(sess,model_name)

    return (prediction)



def use_neural_network(data, prediction):         
    with tf.Session() as sess:
        pred = sess.run(prediction, feed_dict=x: data)
        pred = np.asarray(pred)
    return pred


if __name__ == '__main__':
    result=[]
    Load= start_network()
    for i in data:
        result.append(use_neural_network(i,Load))

【问题讨论】:

【参考方案1】:

通常,您想要实现的目标很容易实现,并且您走在正确的轨道上。在主块中,您有 start_network() 而不是第一行中的 load_network()。我还建议不要使用Load 作为变量名,但这应该不是问题。此外,TensorFlow Session(代码中的sess)应该是一个全局变量,或者您应该在主块或load_network() 函数中对其进行初始化,然后将其传递给use_neural_network() 函数。目前这两个函数中的两个sess变量的写法是局部的,因此指的是不同的会话。

如果您想避免必须使用 neural_network_model( x ) 函数,即在开始时构建模型,您可能希望冻结模型并以这种方式加载它,并嵌入架构也是。最容易遵循这方面的指南,例如 this one。

【讨论】:

以上是关于在不同的会话中加载 TensorFlow 模型的主要内容,如果未能解决你的问题,请参考以下文章

无法在 TensorFlow 2 中加载模型权重

在 python 中加载 Tensorflow Lite 模型

无法在仅 TensorFlow CPU 版本中加载模型

无法在 tensorflow 官方 resnet 模型中加载用于 eval 的图像

当我尝试在jetson tx1中加载卷积预训练模型时,tensorflow中的错误被杀死

无法在 Keras 2.1.0(使用 Tensorflow 1.3.0)中保存的 Keras 2.4.3(使用 Tensorflow 2.3.0)中加载 Keras 模型