TensorFlow 在忽略范围名称或进入新范围名称时恢复

Posted

技术标签:

【中文标题】TensorFlow 在忽略范围名称或进入新范围名称时恢复【英文标题】:Tensorflow restore while ignoring scope name or into new scope name 【发布时间】:2018-08-05 01:14:25 【问题描述】:

我首先训练了网络N,并使用保护程序将其保存到检查点Checkpoint_N。在N 中定义了一些变量范围。

现在,我想使用这个 trained 网络 N 构建一个连体网络,如下所示:

with tf.variable_scope('siameseN',reuse=False) as scope:
  networkN = N()
  embedding_1 = networkN.buildN() 
  # this defines the network graph and all the variables.
  tf.train.Saver().restore(session_variable,Checkpoint_N)
  scope.reuse_variables()
  embedding_2 = networkN.buildN()
  # define 2nd branch of the Siamese, by reusing previously restored variables.

当我执行上述操作时,restore 语句会抛出一个Key Error,即在N 图表中的每个变量的检查点文件中找不到siameseN/conv1

有没有办法做到这一点,而无需更改N 的代码?我基本上只是为N 中的每个变量和操作添加了一个父作用域。我可以通过告诉 tensorflow 忽略父范围之类的东西来将权重恢复到正确的变量吗?

【问题讨论】:

【参考方案1】:

这与:How to restore weights with different names but same shapes Tensorflow?

tf.train.Saver(var_list='variable_name_in_checkpoint':var_to_be_restored_to,...')

可以获取要恢复的变量列表或字典

(e.g. 'variable_name_in_checkpoint':var_to_be_restored_to,...)

您可以通过遍历当前会话变量中的所有变量来准备上述字典,并将会话变量用作值并获取当前变量的名称,并从变量名称中删除'siameseN/'并用作键。理论上应该可以的。

【讨论】:

【参考方案2】:

我不得不稍微修改一下代码,来编写我自己的恢复函数。我决定将检查点文件作为字典加载,变量名作为键,对应的 numpy 数组作为值,如下所示:

checkpoint_path = '/path/to/checkpoint'
from tensorflow.python import pywrap_tensorflow

reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()

key_to_numpy = 
for key in var_to_shape_map:
  key_to_numpy[key] = reader.get_tensor(key)

我已经有了这个创建所有变量的函数,它是从图形N 中以所需名称调用的。我修改它以使用从字典查找中获得的 numpy 数组初始化变量。而且,为了使查找成功,我只是剥离了我添加的父名称范围,如下所示:

init = tf.constant(key_to_numpy[ name.split('siameseN/')[1] ])
var = tf.get_variable(name,  initializer=init)
#var = tf.get_variable(name, shape, initializer=initializer)
return var

这是一种更老套的方法。我没有使用@edit 的答案,因为我已经编写了上面的代码。此外,我所有的权重都是在一个函数中创建的,该函数将这些权重分配给变量var 并返回它。因为这类似于函数式编程,所以变量var 不断被覆盖。 var 永远不会暴露于更高级别的功能。要使用@edit 的答案,我必须为每个初始化使用不同的张量变量名称,并以某种方式将它们公开给更高级别的函数,以便保护程序可以在他们的答案中将它们用作var_to_be_restored_to

但@edit 的解决方案是不那么老套的解决方案,因为它遵循记录在案的用法。所以我会接受这个答案。我所做的可能是另一种解决方案。

【讨论】:

以上是关于TensorFlow 在忽略范围名称或进入新范围名称时恢复的主要内容,如果未能解决你的问题,请参考以下文章

如何将 Excel 列转换为忽略空白单元格的新范围

如何生成列表而忽略日期不符合我要查找的范围的记录?

Kusto 查询渲染函数忽略指定的 Y 范围

Linux 管理员技术

在 C++ 中,阴影变量名称的范围解析(“优先顺序”)是啥?

Linux 云计算-阶段 2-必备知识