Tensorflow 2.2.0收集一个维度上的最大元素。

Posted

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Tensorflow 2.2.0收集一个维度上的最大元素。相关的知识,希望对你有一定的参考价值。

我有以下问题。

我有一个形状为(1600, 29)的张量,我想得到轴1的最大值,结果应该是(1600, 1)张量。为了简化,我将使用(5,3)张量来证明我的问题。

B = tf.constant([[2, 20, 30],
                 [2, 7, 6],
                 [3, 11, 16],
                 [19, 1, 8],
                 [14, 45, 23]])

x = x = tf.math.argmax(B, 1) --> 5 Values [2 1 2 0 1]
z = tf.gather(B, x, axis=1) --> Shape: (5,5) [[30 20 30  2 20]
                                              [ 6  7  6  2  7]
                                              [16 11 16  3 11]
                                              [ 8  1  8 19  1]
                                              [23 45 23 14 45]]

好了,现在,x给了我轴上的最大元素,然而,tf.gather没有返回[30,7,16,19,45],而是一些奇怪的张量。我如何正确地 "减少 "维度?

我的 "相当 "肮脏的方法是这样的。

eye_z = tf.eye(5, 5)

intermed_result = z*eye_z

result = tf.linalg.matvec(intermed_result,tf.transpose(tf.constant([1,1,1,1,1], dtype=tf.float32)))

这样就能得到正确的张量 [30. 7. 16. 19. 45.]

答案

你有 tf.math.reduce_max 为此。

m = tf.math.reduce_max(B, axis=1)

sess = tf.InteractiveSession()
sess.run(m)
# array([30,  7, 16, 19, 45])

以上是关于Tensorflow 2.2.0收集一个维度上的最大元素。的主要内容,如果未能解决你的问题,请参考以下文章

tensorflow切片

TensorFlow:在多个维度上采用 L2 范数

TensorFlow2-维度变换

如何在 Keras / Tensorflow 中将(无,)批量维度重新引入张量?

在 Tensorflow 中删除张量的维度

[TensorFlow]Tensor维度理解