在不同的会话中加载 TensorFlow 模型
Posted
技术标签:
【中文标题】在不同的会话中加载 TensorFlow 模型【英文标题】:Loading Tensorflow model in different session 【发布时间】:2018-03-31 11:49:53 【问题描述】:我对这一切有点陌生,所以你能帮帮我吗?我试图找到这个问题的答案,但一无所获。
我正在尝试在一个单独的函数中在 python 中加载 Tensorflow 模型,这样我就可以在循环中使用该模型,而不必在 for 循环的每次迭代中加载它。
这是我现在的代码:
def load_network():
prediction = neural_network_model(x)
return (prediction)
def use_neural_network(data, prediction):
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
saver = tf.train.import_meta_graph(model_name+'.meta')
saver.restore(sess,model_name)
pred = sess.run(prediction, feed_dict=x: data)
pred = np.asarray(pred)
return pred
if __name__ == '__main__':
result=[]
Load= start_network()
for i in data:
result.append(use_neural_network(i,Load))
我想得到这样的东西:
def load_network():
prediction = neural_network_model(x)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
saver = tf.train.import_meta_graph(model_name+'.meta')
saver.restore(sess,model_name)
return (prediction)
def use_neural_network(data, prediction):
with tf.Session() as sess:
pred = sess.run(prediction, feed_dict=x: data)
pred = np.asarray(pred)
return pred
if __name__ == '__main__':
result=[]
Load= start_network()
for i in data:
result.append(use_neural_network(i,Load))
【问题讨论】:
【参考方案1】:通常,您想要实现的目标很容易实现,并且您走在正确的轨道上。在主块中,您有 start_network()
而不是第一行中的 load_network()
。我还建议不要使用Load
作为变量名,但这应该不是问题。此外,TensorFlow Session(代码中的sess
)应该是一个全局变量,或者您应该在主块或load_network()
函数中对其进行初始化,然后将其传递给use_neural_network()
函数。目前这两个函数中的两个sess
变量的写法是局部的,因此指的是不同的会话。
如果您想避免必须使用 neural_network_model( x )
函数,即在开始时构建模型,您可能希望冻结模型并以这种方式加载它,并嵌入架构也是。最容易遵循这方面的指南,例如 this one。
【讨论】:
以上是关于在不同的会话中加载 TensorFlow 模型的主要内容,如果未能解决你的问题,请参考以下文章
在 python 中加载 Tensorflow Lite 模型
无法在 tensorflow 官方 resnet 模型中加载用于 eval 的图像
当我尝试在jetson tx1中加载卷积预训练模型时,tensorflow中的错误被杀死
无法在 Keras 2.1.0(使用 Tensorflow 1.3.0)中保存的 Keras 2.4.3(使用 Tensorflow 2.3.0)中加载 Keras 模型