TensorFlow 运算符重载
Posted
技术标签:
【中文标题】TensorFlow 运算符重载【英文标题】:TensorFlow operator overloading 【发布时间】:2016-05-07 19:13:13 【问题描述】:有什么区别
tf.add(x, y)
和
x + y
在 TensorFlow 中?当您使用+
而不是tf.add()
构建图表时,您的计算图表会有什么不同?
更一般地说,+
或其他操作是否对张量重载?
【问题讨论】:
【参考方案1】:如果x
或y
中的至少一个是tf.Tensor
对象,则表达式tf.add(x, y)
和x + y
是等价的。您可能使用tf.add()
的主要原因是为创建的操作指定显式的name
关键字参数,这在重载运算符版本中是不可能的。
请注意,如果 x
和 y
都不是 tf.Tensor
(例如,如果它们是 NumPy 数组),那么 x + y
将不会创建 TensorFlow 操作。 tf.add()
总是创建一个 TensorFlow 操作并将其参数转换为 tf.Tensor
对象。因此,如果您正在编写一个可能同时接受张量和 NumPy 数组的库函数,您可能更喜欢使用tf.add()
。
TensorFlow Python API 中重载了以下运算符:
__neg__
(一元-
)
__abs__
(abs()
)
__invert__
(一元~
)
__add__
(二进制+
)
__sub__
(二进制-
)
__mul__
(二进制元素*
)
__div__
(Python 2 中的二进制 /
)
__floordiv__
(Python 3 中的二进制 //
)
__truediv__
(Python 3 中的二进制 /
)
__mod__
(二进制%
)
__pow__
(二进制**
)
__and__
(二进制&
)
__or__
(二进制|
)
__xor__
(二进制^
)
__lt__
(二进制<
)
__le__
(二进制<=
)
__gt__
(二进制>
)
__ge__
(二进制>=
)
请注意,__eq__
(二进制 ==
)没有重载。 x == y
将简单地返回一个 Python 布尔值,无论 x
和 y
是否引用相同的张量。您需要明确使用tf.equal()
来检查元素相等性。不相等也一样,__ne__
(二进制!=
)。
【讨论】:
如果我们想要==
运算符进行张量标量比较怎么办?
@K.Wanter 你需要使用tf.equal()
。 ==
和 !=
没有过载。 (或者( x <= y ) & ( y <= x )
如果你觉得叛逆。:))
要补充的一点是,您还可以像在 numpy 中一样使用 @
重载 tf.matmul()
。
看起来不像==
这么简单。例如,tf.constant([0]) == tf.constant([0])
是一个包含[True]
的tf.Tensor
,至少在 tf2.似乎急切的执行有效果【参考方案2】:
Mrry 很好地解释说没有真正的区别。我会在使用tf.add
有益时添加。
tf.add 有一个重要的参数是name
。它允许您在图形中命名操作,该图形将在 tensorboard 中可见。所以我的经验法则是,如果在 tensorboard 中命名一个操作有好处,我使用 tf.
等价物,否则我会为了简洁而使用重载版本。
【讨论】:
以上是关于TensorFlow 运算符重载的主要内容,如果未能解决你的问题,请参考以下文章
GroovyGroovy 运算符重载 ( 运算符重载 | 运算符重载对应方法 )