在 tensorflow 中恢复图形失败,因为没有要保存的变量
Posted
技术标签:
【中文标题】在 tensorflow 中恢复图形失败,因为没有要保存的变量【英文标题】:Restoring graph in tensorflow fails because there is no variable to save 【发布时间】:2016-08-29 16:02:53 【问题描述】:我知道stack和github等上有无数关于如何在Tensorflow中恢复训练好的模型的问题。我读过其中的大部分(1,2,3)。
我的问题与 3 几乎完全相同,但是如果可能的话,我希望以不同的方式解决它,因为我的训练和测试需要在从 shell 调用的单独脚本中,我不想添加确切的我用来在测试脚本中定义图形的同一行,所以我不能使用 tensorflow FLAGS 和其他基于手动重新运行图形的答案。
我也不想 sess.run 每个变量并手动手动映射它们,因为我的图表很大(使用 import_graph_def 和参数 input_map)。
所以我运行一些图表并在特定脚本中对其进行训练。例如(但没有训练部分)
#Script 1
import tensorflow as tf
import cPickle as pickle
x=tf.Variable(42)
saver=tf.train.Saver()
sess=tf.Session()
#Saving the graph
graph_def=sess.graph_def
with open('graph.pkl','wb') as output:
pickle.dump(graph_def,output,HIGHEST_PROTOCOL)
#Training the model
sess.run(tf.initialize_all_variables())
#Saving the variables
saver.save(sess,"pretrained_model.ckpt")
我现在保存了图表和变量,因此即使我的图表中有额外的训练节点,我也应该能够从另一个脚本运行我的测试模型。
#Script 2
import tensorflow as tf
import cPickle as pickle
sess=tf.Session()
with open('graph.pkl','rb') as input:
graph_def=pickle.load(input)
tf.import_graph_def(graph_def,name='persisted')
那么显然我想使用保护程序恢复变量,但我遇到了与 3 相同的问题,因为没有找到要保存的变量甚至创建保护程序。所以我不能写:
saver=tf.train.Saver()
saver.restore(sess,"pretrained_model.ckpt")
有没有办法绕过这些限制?我认为通过导入图形可以恢复每个节点中未初始化的变量,但似乎不是。我真的需要像大多数给出的答案一样重新运行它吗?
【问题讨论】:
【参考方案1】:变量列表保存在Collection
中,而GraphDef
中没有保存。 Saver
默认使用ops.GraphKeys.VARIABLES
集合中的列表(可通过tf.all_variables()
访问),如果您从GraphDef
恢复而不是使用Python API 来构建模型,则该集合为空。您可以在tf.train.Saver(var_list=['MyVariable1:0', 'MyVariable2:0',...])
中手动指定变量列表。
您也可以使用MetaGraphDef
来代替GraphDef
来保存集合,最近添加了一个MetaGraphDef HowTo
【讨论】:
感谢我正在寻找的东西!很抱歉再问一个关于恢复模型的问题!但是我认为这对其他人也可能有用! @Yaroslav,我用过MetaGraphDef
,但我遇到了同样的问题。你能在这里看看我的问题吗:***.com/questions/47762114/…
链接失效了,我猜应该是this?【参考方案2】:
据我所知和我的测试,您不能简单地将名称传递给tf.train.Saver
对象。它必须是变量列表或字典。
我还想从 graph_def 读取模型,然后使用保护程序加载变量 - 但是尝试它只会导致错误消息:“要保存的变量不是变量”
【讨论】:
以上是关于在 tensorflow 中恢复图形失败,因为没有要保存的变量的主要内容,如果未能解决你的问题,请参考以下文章