如何在 TensorFlow 图中添加 if 条件?
Posted
技术标签:
【中文标题】如何在 TensorFlow 图中添加 if 条件?【英文标题】:How to add if condition in a TensorFlow graph? 【发布时间】:2016-06-20 08:52:44 【问题描述】:假设我有以下代码:
x = tf.placeholder("float32", shape=[None, ins_size**2*3], name = "x_input")
condition = tf.placeholder("int32", shape=[1, 1], name = "condition")
W = tf.Variable(tf.zeros([ins_size**2*3,label_option]), name = "weights")
b = tf.Variable(tf.zeros([label_option]), name = "bias")
if condition > 0:
y = tf.nn.softmax(tf.matmul(x, W) + b)
else:
y = tf.nn.softmax(tf.matmul(x, W) - b)
if
语句在计算中是否有效(我不这么认为)?如果没有,如何在 TensorFlow 计算图中添加 if
语句?
【问题讨论】:
【参考方案1】:TensorFlow 2.0
TF 2.0 introduces a feature called AutoGraph 允许您将 python 代码 JIT 编译为 Graph 执行。这意味着您可以使用 python 控制流语句(是的,这包括 if
语句)。从文档中,
AutoGraph 支持常见的 Python 语句,如
while
、for
、if
、break
、continue
和return
,支持嵌套。这意味着你 可以在while
和if
的条件下使用Tensor表达式 语句,或在for
循环中迭代张量。
您需要定义一个实现您的逻辑的函数并使用tf.function
对其进行注释。这是文档中的一个修改示例:
import tensorflow as tf
@tf.function
def sum_even(items):
s = 0
for c in items:
if tf.equal(c % 2, 0):
s += c
return s
sum_even(tf.constant([10, 12, 15, 20]))
# <tf.Tensor: id=1146, shape=(), dtype=int32, numpy=42>
【讨论】:
你为什么使用tf.equal()
?不应该可以用==
让AutoGraph自动编译吗?
@problemofficer 这是一个很好的问题。我假设(和你一样)同样的事情,但被咬了。这是我问的一个问题,讨论了这种行为:***.com/questions/56616485/…【参考方案2】:
if
语句在这里不起作用是正确的,因为条件是在图形构建时评估的,而您可能希望条件依赖于在运行时提供给占位符的值。 (事实上,它总是采用第一个分支,因为condition > 0
的计算结果为Tensor
,即"truthy" in Python。)
为了支持条件控制流,TensorFlow 提供了tf.cond()
运算符,它根据布尔条件评估两个分支之一。为了向您展示如何使用它,我将重写您的程序,以便为简单起见,condition
是一个标量 tf.int32
值:
x = tf.placeholder(tf.float32, shape=[None, ins_size**2*3], name="x_input")
condition = tf.placeholder(tf.int32, shape=[], name="condition")
W = tf.Variable(tf.zeros([ins_size**2 * 3, label_option]), name="weights")
b = tf.Variable(tf.zeros([label_option]), name="bias")
y = tf.cond(condition > 0, lambda: tf.matmul(x, W) + b, lambda: tf.matmul(x, W) - b)
【讨论】:
@mrry 两个分支都默认执行吗?我有 tf.cond(c, lambda x: train_op1, lambda x: train_op2) 并且在每次执行 cond 时都会执行两个 train_ops,而与 c 的值无关。我做错了吗? @PiotrDabkowski 这是tf.cond()
的一个有时令人惊讶的行为,in the docs 涉及到。简而言之,您需要创建要在 inside 相应的 lambdas 中有条件地运行的操作。您在 lambdas 之外创建但在任一分支中引用的所有内容都将在这两种情况下执行。
@mrry 哇,太出乎意料了 :) 感谢您的回答,在函数内部定义操作解决了这个问题。
条件/(逻辑的应用)元素是明智的吗?以上是关于如何在 TensorFlow 图中添加 if 条件?的主要内容,如果未能解决你的问题,请参考以下文章
Magento 1.9 - 如何在 CSS 中添加 if - else 条件?