TensorFlow:argmax(-min)

Posted

技术标签:

【中文标题】TensorFlow:argmax(-min)【英文标题】:TensorFlow: argmax (-min) 【发布时间】:2016-10-31 20:15:15 【问题描述】:

我刚刚注意到 TensorFlow 中出现了一个意想不到的(至少对我而言)行为。我以为tf.argmax (-argmin) 从外到内在张量的行列上运行,但显然不是?!

例子:

import numpy as np
import tensorflow as tf

sess = tf.InteractiveSession()

arr = np.array([[31, 23,  4, 24, 27, 34],
                [18,  3, 25,  0,  6, 35],
                [28, 14, 33, 22, 20,  8],
                [13, 30, 21, 19,  7,  9],
                [16,  1, 26, 32,  2, 29],
                [17, 12,  5, 11, 10, 15]])

# arr has rank 2 and shape (6, 6)
tf.rank(arr).eval()
> 2
tf.shape(arr).eval()
> array([6, 6], dtype=int32)

tf.argmax 有两个参数:inputdimension。由于数组arr 的索引是arr[rows, columns],我希望tf.argmax(arr, 0) 返回每​​行最大元素的索引,而我希望tf.argmax(arr, 1) 返回每​​列最大元素。 tf.argmin 也是如此。

然而,事实恰恰相反:

tf.argmax(arr, 0).eval()
> array([0, 3, 2, 4, 0, 1])

# 0 -> 31 (arr[0, 0])
# 3 -> 30 (arr[3, 1])
# 2 -> 33 (arr[2, 2])
# ...
# thus, this is clearly searching for the maximum element
# for every column, and *not* for every row

tf.argmax(arr, 1).eval()
> array([5, 5, 2, 1, 3, 0])

# 5 -> 34 (arr[0, 5])
# 5 -> 35 (arr[1, 5])
# 2 -> 33 (arr[2, 2])
# ...
# this clearly returns the maximum element per row,
# albeit 'dimension' was set to 1

有人可以解释这种行为吗?

推广每个 n 维张量 tt[i, j, k, ...] 索引。因此,t 的秩为 n,形状为 (i, j, k, ...)。由于维度 0 对应于 i,维度 1 对应于 j,依此类推。为什么tf.argmax (& -argmin) 会忽略这个方案?

【问题讨论】:

【参考方案1】:

tf.argmaxdimension 参数视为您减少的轴。 tf.argmax(arr, 0) 减少维度 0,即行。跨行减少意味着您将获得每个单独列的 argmax。

这可能违反直觉,但它符合 tf.reduce_max 等中使用的约定。

【讨论】:

也与numpy的argmax相同的约定 你能解释一下为什么减少跨行意味着获取每个单独列的 argmax 吗?另外:这对于 n 维张量有何表现?我有点迷失在弄清楚哪个维度与减少 5D 张量中的 i, j, k, lm 相关。 根据定义,如果您搜索最大 across 行,您将搜索 within 列。对于任何d 维数组,采用argmax across 轴意味着,对于d-1 剩余索引的任何可能组合,您正在搜索最大 其中 arr[ind1, ind2, ..., ind_i_minus_1, : , ind_i_plus_1, ..., ind_d].【参考方案2】:

在一个 n 维张量中,任何给定的维度都有 n-1 个维度,形成一个离散的二维子空间。按照同样的逻辑,它有 n-2 个 3 维子空间,一直到 n - (n-1), n 维子空间。您可以将任何聚合表示为剩余子空间内的函数,或跨越正在聚合的子空间。由于聚合后子空间将不再存在,Tensorflow 选择将其实现为跨该维度的操作。

坦率地说,它是 Tensorflow 的创建者的实现选择,现在你知道了。

【讨论】:

以上是关于TensorFlow:argmax(-min)的主要内容,如果未能解决你的问题,请参考以下文章

tensorflow-argmax

argmax in DNN

TensorFlow函数 tf.argmax()

TensorFlow函数 tf.argmax()

tensorflow数据统计

如何在 Cplex (java) 中添加关于 argmax o min 的约束?