在TensorFlow中,函数'tf.one_hot'中的参数'axis'是啥

Posted

技术标签:

【中文标题】在TensorFlow中,函数\'tf.one_hot\'中的参数\'axis\'是啥【英文标题】:In TensorFlow, what is the argument 'axis' in the function 'tf.one_hot'在TensorFlow中,函数'tf.one_hot'中的参数'axis'是什么 【发布时间】:2018-06-13 10:53:18 【问题描述】:

谁能帮忙解释一下axisTensorFlowone_hot函数中是什么?

根据documentation:

axis:要填充的轴(默认值:-1,新的最内层轴)

最近我在SO was an explanation 上找到了与Pandas 相关的答案:

不确定上下文是否同样适用。

【问题讨论】:

【参考方案1】:

这是一个例子:

x = tf.constant([0, 1, 2])

... 是输入张量和N=4(每个索引都转换为4D向量)。

axis=-1

计算 one_hot_1 = tf.one_hot(x, 4).eval() 会产生一个 (3, 4) 张量:

[[ 1.  0.  0.  0.]
 [ 0.  1.  0.  0.]
 [ 0.  0.  1.  0.]]

... 最后一个维度是单热编码的(清晰可见)。这对应于默认的axis=-1,即最后一个

axis=0

现在,计算 one_hot_2 = tf.one_hot(x, 4, axis=0).eval() 会产生一个 (4, 3) 张量,它不能立即识别为 one-hot 编码:

[[ 1.  0.  0.]
 [ 0.  1.  0.]
 [ 0.  0.  1.]
 [ 0.  0.  0.]]

这是因为 one-hot 编码是沿 0 轴完成的,必须转置矩阵才能看到之前的编码。当输入的维度更高时,情况会变得更加复杂,但想法是一样的:不同之处在于用于 one-hot 编码的 额外 维度的位置。

【讨论】:

谢谢你的解释,虽然它对我来说非常密集,所以一次问一个。你怎么知道x = tf.constant([[1, 1, 2], [0, 1, 2]]) 会产生一个4D 向量?...是不是它是一个由二维数组作为元素的数组? 我想我需要阅读。我有太多的问题。有没有可能用二维数组来简化你的答案?......我认为,如果失败了,我将不得不回到第一原则......毫无疑问,你的答案可能是正确的 没关系。我试图为每件事选择不同的维度以避免误解。输入x(2, 3)。编码结果是 4D,因为我们设置了N=4(想想类的数量)。这就是结果为(2, 3, 4)(4, 3, 2) 的原因,具体取决于展示位置。 实际上,是的,您也可以看到与x = tf.constant([0, 1, 2]) 的区别:结果是(3, 4)(4, 3) 哈哈。我要回去阅读第一原则。我绝对不明白。与您的回答无关,但我对这个主题缺乏深度。谢谢,不管怎样。将其标记为信任您对该主题的了解的答案...然后我将重新阅读:)【参考方案2】:

对我来说,轴转换为“你在哪里添加额外的数字来增加维度”。至少我是这样解释它并作为助记符的。

例如,您有 [1,2,3,0,2,1],它的形状为 (1,6)。这意味着它是一个一维数组。 one_hot 在原始数组的每个位置添加零并将位置转换为 1,为此,原始数组必须比原始数组多 1 个维度,并且轴告诉函数在哪里添加它,这个新维度将识别示例


轴=1

您添加第二个维度并保留第一个维度。这将产生一个 (6,4) 数组。因此,对于结果数组,您使用第一个维度 (0) 来了解您看到的是哪个示例,并使用第二个维度 (1,新的) 来了解该类是否处于活动状态。 newArr[0][1]=1 表示示例 0,类 1,在这种情况下表示示例 0 属于类 1。
   0   1   2   3  <- class

[[ 0.  1.  0.  0.]   <- example 0
 [ 0.  0.  1.  0.]   <- example 1
 [ 0.  0.  0.  1.]   <- example 2
 [ 1.  0.  0.  0.]   <- example 3
 [ 0.  0.  1.  0.]   <- example 4
 [ 0.  1.  0.  0.]]  <- example 5

轴=0

您添加第一个维度并移动现有维度。这将产生一个 (4,6) 数组。因此,对于结果数组,您使用第一个维度(0,新维度)来了解该类是否处于活动状态,并使用第二个维度 (1) 来了解您看到的示例。 newArr[0][1]=0 表示 0 类,示例 1,在这种情况下表示示例 1 不属于 0 类。
   0   1   2   3   4   5  <- example

[[ 0.  0.  0.  1.  0.  0.]   <- class 0
 [ 1.  0.  0.  0.  0.  1.]   <- class 1
 [ 0.  1.  0.  0.  1.  0.]   <- class 2
 [ 0.  0.  1.  0.  0.  0.]]  <- class 3

【讨论】:

很好的解释,但输入数组不是 (1, 6) 而不是 (6, 1)? 我对你的回答不满意...对不起【参考方案3】:

对我来说,我是这样理解的—— (注意documentation中提到的索引只是类标签的信息,它可以是标量或向量或矩阵) 如果您的索引只是一个标量,则不需要轴。 但是,如果它是一个向量,您可以选择特征和类的方向2` 这里单热向量的图像有一行作为深度(类)和一列作为相应的特征(标签),因此对于这种情况,轴的值为 0。 同样,如果您想要特征 x 深度,则轴的值为 -1。

同样,如果索引是矩阵,那么您可以选择以下方向


(batch 表示索引中的行)3

batch x features x depth if axis == -1
batch x depth x features if axis == 1
depth x batch x features if axis == 0

【讨论】:

以上是关于在TensorFlow中,函数'tf.one_hot'中的参数'axis'是啥的主要内容,如果未能解决你的问题,请参考以下文章

如何在 TensorFlow 中使用“group_by_window”函数

在 tensorflow 中编写自定义成本函数

有没有办法知道在tensorflow中调用了哪个c++核心函数?

在 TensorFlow 的 SKFlow 模型训练中应用自定义成本函数

tensorflow官方文档中的sub 和mul中的函数已经在API中改名了

在TensorFlow中,函数'tf.one_hot'中的参数'axis'是啥