TensorFlow 在忽略范围名称或进入新范围名称时恢复
Posted
技术标签:
【中文标题】TensorFlow 在忽略范围名称或进入新范围名称时恢复【英文标题】:Tensorflow restore while ignoring scope name or into new scope name 【发布时间】:2018-08-05 01:14:25 【问题描述】:我首先训练了网络N
,并使用保护程序将其保存到检查点Checkpoint_N
。在N
中定义了一些变量范围。
现在,我想使用这个 trained 网络 N
构建一个连体网络,如下所示:
with tf.variable_scope('siameseN',reuse=False) as scope:
networkN = N()
embedding_1 = networkN.buildN()
# this defines the network graph and all the variables.
tf.train.Saver().restore(session_variable,Checkpoint_N)
scope.reuse_variables()
embedding_2 = networkN.buildN()
# define 2nd branch of the Siamese, by reusing previously restored variables.
当我执行上述操作时,restore 语句会抛出一个Key Error
,即在N
图表中的每个变量的检查点文件中找不到siameseN/conv1
。
有没有办法做到这一点,而无需更改N
的代码?我基本上只是为N
中的每个变量和操作添加了一个父作用域。我可以通过告诉 tensorflow 忽略父范围之类的东西来将权重恢复到正确的变量吗?
【问题讨论】:
【参考方案1】:这与:How to restore weights with different names but same shapes Tensorflow?
tf.train.Saver(var_list='variable_name_in_checkpoint':var_to_be_restored_to,...')
可以获取要恢复的变量列表或字典
(e.g. 'variable_name_in_checkpoint':var_to_be_restored_to,...)
您可以通过遍历当前会话变量中的所有变量来准备上述字典,并将会话变量用作值并获取当前变量的名称,并从变量名称中删除'siameseN/'并用作键。理论上应该可以的。
【讨论】:
【参考方案2】:我不得不稍微修改一下代码,来编写我自己的恢复函数。我决定将检查点文件作为字典加载,变量名作为键,对应的 numpy 数组作为值,如下所示:
checkpoint_path = '/path/to/checkpoint'
from tensorflow.python import pywrap_tensorflow
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
key_to_numpy =
for key in var_to_shape_map:
key_to_numpy[key] = reader.get_tensor(key)
我已经有了这个创建所有变量的函数,它是从图形N
中以所需名称调用的。我修改它以使用从字典查找中获得的 numpy 数组初始化变量。而且,为了使查找成功,我只是剥离了我添加的父名称范围,如下所示:
init = tf.constant(key_to_numpy[ name.split('siameseN/')[1] ])
var = tf.get_variable(name, initializer=init)
#var = tf.get_variable(name, shape, initializer=initializer)
return var
这是一种更老套的方法。我没有使用@edit 的答案,因为我已经编写了上面的代码。此外,我所有的权重都是在一个函数中创建的,该函数将这些权重分配给变量var
并返回它。因为这类似于函数式编程,所以变量var
不断被覆盖。 var
永远不会暴露于更高级别的功能。要使用@edit 的答案,我必须为每个初始化使用不同的张量变量名称,并以某种方式将它们公开给更高级别的函数,以便保护程序可以在他们的答案中将它们用作var_to_be_restored_to
。
但@edit 的解决方案是不那么老套的解决方案,因为它遵循记录在案的用法。所以我会接受这个答案。我所做的可能是另一种解决方案。
【讨论】:
以上是关于TensorFlow 在忽略范围名称或进入新范围名称时恢复的主要内容,如果未能解决你的问题,请参考以下文章