Tensorflow 2.2.0收集一个维度上的最大元素。
Posted
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Tensorflow 2.2.0收集一个维度上的最大元素。相关的知识,希望对你有一定的参考价值。
我有以下问题。
我有一个形状为(1600, 29)的张量,我想得到轴1的最大值,结果应该是(1600, 1)张量。为了简化,我将使用(5,3)张量来证明我的问题。
B = tf.constant([[2, 20, 30],
[2, 7, 6],
[3, 11, 16],
[19, 1, 8],
[14, 45, 23]])
x = x = tf.math.argmax(B, 1) --> 5 Values [2 1 2 0 1]
z = tf.gather(B, x, axis=1) --> Shape: (5,5) [[30 20 30 2 20]
[ 6 7 6 2 7]
[16 11 16 3 11]
[ 8 1 8 19 1]
[23 45 23 14 45]]
好了,现在,x给了我轴上的最大元素,然而,tf.gather没有返回[30,7,16,19,45],而是一些奇怪的张量。我如何正确地 "减少 "维度?
我的 "相当 "肮脏的方法是这样的。
eye_z = tf.eye(5, 5)
intermed_result = z*eye_z
result = tf.linalg.matvec(intermed_result,tf.transpose(tf.constant([1,1,1,1,1], dtype=tf.float32)))
这样就能得到正确的张量 [30. 7. 16. 19. 45.]
答案
你有 tf.math.reduce_max
为此。
m = tf.math.reduce_max(B, axis=1)
sess = tf.InteractiveSession()
sess.run(m)
# array([30, 7, 16, 19, 45])
以上是关于Tensorflow 2.2.0收集一个维度上的最大元素。的主要内容,如果未能解决你的问题,请参考以下文章