在 tensorflow 中恢复图形失败,因为没有要保存的变量

Posted

技术标签:

【中文标题】在 tensorflow 中恢复图形失败,因为没有要保存的变量【英文标题】:Restoring graph in tensorflow fails because there is no variable to save 【发布时间】:2016-08-29 16:02:53 【问题描述】:

我知道stack和github等上有无数关于如何在Tensorflow中恢复训练好的模型的问题。我读过其中的大部分(1,2,3)。

我的问题与 3 几乎完全相同,但是如果可能的话,我希望以不同的方式解决它,因为我的训练和测试需要在从 shell 调用的单独脚本中,我不想添加确切的我用来在测试脚本中定义图形的同一行,所以我不能使用 tensorflow FLAGS 和其他基于手动重新运行图形的答案。

我也不想 sess.run 每个变量并手动手动映射它们,因为我的图表很大(使用 import_graph_def 和参数 input_map)。

所以我运行一些图表并在特定脚本中对其进行训练。例如(但没有训练部分)

#Script 1
import tensorflow as tf
import cPickle as pickle

x=tf.Variable(42)
saver=tf.train.Saver()
sess=tf.Session()
#Saving the graph
graph_def=sess.graph_def
with open('graph.pkl','wb') as output:
  pickle.dump(graph_def,output,HIGHEST_PROTOCOL)


#Training the model
sess.run(tf.initialize_all_variables())
#Saving the variables
saver.save(sess,"pretrained_model.ckpt")

我现在保存了图表和变量,因此即使我的图表中有额外的训练节点,我也应该能够从另一个脚本运行我的测试模型。

#Script 2
import tensorflow as tf
import cPickle as pickle

sess=tf.Session()
with open('graph.pkl','rb') as input:
  graph_def=pickle.load(input)


tf.import_graph_def(graph_def,name='persisted')

那么显然我想使用保护程序恢复变量,但我遇到了与 3 相同的问题,因为没有找到要保存的变量甚至创建保护程序。所以我不能写:

saver=tf.train.Saver()
saver.restore(sess,"pretrained_model.ckpt")

有没有办法绕过这些限制?我认为通过导入图形可以恢复每个节点中未初始化的变量,但似乎不是。我真的需要像大多数给出的答案一样重新运行它吗?

【问题讨论】:

【参考方案1】:

变量列表保存在Collection 中,而GraphDef 中没有保存。 Saver 默认使用ops.GraphKeys.VARIABLES 集合中的列表(可通过tf.all_variables() 访问),如果您从GraphDef 恢复而不是使用Python API 来构建模型,则该集合为空。您可以在tf.train.Saver(var_list=['MyVariable1:0', 'MyVariable2:0',...]) 中手动指定变量列表。

您也可以使用MetaGraphDef 来代替GraphDef 来保存集合,最近添加了一个MetaGraphDef HowTo

【讨论】:

感谢我正在寻找的东西!很抱歉再问一个关于恢复模型的问题!但是我认为这对其他人也可能有用! @Yaroslav,我用过MetaGraphDef,但我遇到了同样的问题。你能在这里看看我的问题吗:***.com/questions/47762114/… 链接失效了,我猜应该是this?【参考方案2】:

据我所知和我的测试,您不能简单地将名称传递给tf.train.Saver 对象。它必须是变量列表或字典。

我还想从 graph_def 读取模型,然后使用保护程序加载变量 - 但是尝试它只会导致错误消息:“要保存的变量不是变量”

【讨论】:

以上是关于在 tensorflow 中恢复图形失败,因为没有要保存的变量的主要内容,如果未能解决你的问题,请参考以下文章

Tensorflow:恢复图形和模型,然后在单个图像上运行评估

TensorFlow 从文件中保存/加载图形

如何在 TensorFlow 中调试 NaN 值?

tensorflow的断点续训

Tensorflow:如何使用恢复的模型?

TensorFlow 模型恢复(恢复训练似乎从头开始)