用在张量流中具有变量依赖性的自定义操作替换图中的节点

Posted

技术标签:

【中文标题】用在张量流中具有变量依赖性的自定义操作替换图中的节点【英文标题】:Replacing a node in graph with custom op having variable dependency in tensorflow 【发布时间】:2016-10-05 20:17:30 【问题描述】:

我正在尝试用执行相同操作的自定义操作替换图中完成的计算。

假设图形有一个常量A 和权重变量W,我创建自定义操作来获取这两个输入并进行整个计算(除了权重更新的最后一步):

custom_op_tensor = custom_module.custom_op([A,W])
g_def = tf.get_default_graph().as_graph_def()
input_map =  tensor.name : custom_op_tensor 
train_op, = tf.import_graph_def(g_def, input_map=input_map, return_elements=[train_op])

在导入图def之后,有两个W,一个来自原始图def,一个在导入的图中。当我们运行训练操作时,自定义操作最终会读取旧的W,而新的W 会更新。结果,梯度下降最终无法做正确的事情。

问题是custom_op的实例化需要输入权重张量W。新的W 仅在导入后才知道。而且,导入需要自定义操作。 如何解决这个问题?

【问题讨论】:

您在询问如何用另一个操作替换图中的操作。直到最近,图表都是仅附加的,不可能做到这一点。但是,最近添加了一个图形编辑器库,也许那里有一些功能可以提供帮助——tensorflow.org/versions/r0.11/api_docs/python/… 【参考方案1】:

您能否详细说明您使用的是哪个版本的 Tensorflow:r0.08、r0.09、r0.10、r0.11?

这是不可能用另一个操作来改变图中的一个操作的。 但是如果您可以访问 W,您仍然可以在运行更新它的 train op 之前制作 W 的备份副本(使用 deepcopy() from copy module )?

问候

【讨论】:

以上是关于用在张量流中具有变量依赖性的自定义操作替换图中的节点的主要内容,如果未能解决你的问题,请参考以下文章

Interlocked.Exchange() 具有依赖于读取锁定变量的自定义条件

有啥方法可以查看张量板图中的参数总数?

在张量流中,如何迭代存储在张量中的输入序列?

有没有办法在张量流中剪辑中间爆炸梯度

如何在张量流中使用非常大(> 2M)的词嵌入?

tf.shape() 在张量流中得到错误的形状