TensorFlow 运算符重载

Posted

技术标签:

【中文标题】TensorFlow 运算符重载【英文标题】:TensorFlow operator overloading 【发布时间】:2016-05-07 19:13:13 【问题描述】:

有什么区别

   tf.add(x, y)

   x + y

在 TensorFlow 中?当您使用+ 而不是tf.add() 构建图表时,您的计算图表会有什么不同?

更一般地说,+ 或其他操作是否对张量重载?

【问题讨论】:

【参考方案1】:

如果xy 中的至少一个是tf.Tensor 对象,则表达式tf.add(x, y)x + y 是等价的。您可能使用tf.add() 的主要原因是为创建的操作指定显式的name 关键字参数,这在重载运算符版本中是不可能的。

请注意,如果 xy 都不是 tf.Tensor(例如,如果它们是 NumPy 数组),那么 x + y 将不会创建 TensorFlow 操作。 tf.add() 总是创建一个 TensorFlow 操作并将其参数转换为 tf.Tensor 对象。因此,如果您正在编写一个可能同时接受张量和 NumPy 数组的库函数,您可能更喜欢使用tf.add()

TensorFlow Python API 中重载了以下运算符:

__neg__(一元-__abs__ (abs()) __invert__(一元~__add__(二进制+__sub__(二进制-__mul__(二进制元素*__div__(Python 2 中的二进制 /__floordiv__(Python 3 中的二进制 //__truediv__(Python 3 中的二进制 /__mod__(二进制%__pow__(二进制**__and__(二进制&__or__(二进制|__xor__(二进制^__lt__(二进制<__le__(二进制<=__gt__(二进制>__ge__(二进制>=

请注意,__eq__(二进制 ==没有重载。 x == y 将简单地返回一个 Python 布尔值,无论 xy 是否引用相同的张量。您需要明确使用tf.equal() 来检查元素相等性。不相等也一样,__ne__(二进制!=)。

【讨论】:

如果我们想要== 运算符进行张量标量比较怎么办? @K.Wanter 你需要使用tf.equal()==!= 没有过载。 (或者( x <= y ) & ( y <= x )如果你觉得叛逆。:)) 要补充的一点是,您还可以像在 numpy 中一样使用 @ 重载 tf.matmul() 看起来不像== 这么简单。例如,tf.constant([0]) == tf.constant([0]) 是一个包含[True]tf.Tensor,至少在 tf2.似乎急切的执行有效果【参考方案2】:

Mrry 很好地解释说没有真正的区别。我会在使用tf.add 有益时添加。

tf.add 有一个重要的参数是name。它允许您在图形中命名操作,该图形将在 tensorboard 中可见。所以我的经验法则是,如果在 tensorboard 中命名一个操作有好处,我使用 tf. 等价物,否则我会为了简洁而使用重载版本。

【讨论】:

以上是关于TensorFlow 运算符重载的主要内容,如果未能解决你的问题,请参考以下文章

GroovyGroovy 运算符重载 ( 运算符重载 | 运算符重载对应方法 )

C++ 运算符重载

运算符重载1

什么运算符一定要重载友元函数,什么时候一定要重载为成员函数?

利用运算符重载实现Date类

C++运算符重载