在TensorFlow中,函数'tf.one_hot'中的参数'axis'是啥
Posted
技术标签:
【中文标题】在TensorFlow中,函数\'tf.one_hot\'中的参数\'axis\'是啥【英文标题】:In TensorFlow, what is the argument 'axis' in the function 'tf.one_hot'在TensorFlow中,函数'tf.one_hot'中的参数'axis'是什么 【发布时间】:2018-06-13 10:53:18 【问题描述】:谁能帮忙解释一下axis
在TensorFlow
的one_hot
函数中是什么?
根据documentation:
axis:要填充的轴(默认值:-1,新的最内层轴)
最近我在SO was an explanation 上找到了与Pandas 相关的答案:
不确定上下文是否同样适用。
【问题讨论】:
【参考方案1】:这是一个例子:
x = tf.constant([0, 1, 2])
... 是输入张量和N=4
(每个索引都转换为4D向量)。
axis=-1
计算 one_hot_1 = tf.one_hot(x, 4).eval()
会产生一个 (3, 4)
张量:
[[ 1. 0. 0. 0.]
[ 0. 1. 0. 0.]
[ 0. 0. 1. 0.]]
... 最后一个维度是单热编码的(清晰可见)。这对应于默认的axis=-1
,即最后一个。
axis=0
现在,计算 one_hot_2 = tf.one_hot(x, 4, axis=0).eval()
会产生一个 (4, 3)
张量,它不能立即识别为 one-hot 编码:
[[ 1. 0. 0.]
[ 0. 1. 0.]
[ 0. 0. 1.]
[ 0. 0. 0.]]
这是因为 one-hot 编码是沿 0 轴完成的,必须转置矩阵才能看到之前的编码。当输入的维度更高时,情况会变得更加复杂,但想法是一样的:不同之处在于用于 one-hot 编码的 额外 维度的位置。
【讨论】:
谢谢你的解释,虽然它对我来说非常密集,所以一次问一个。你怎么知道x = tf.constant([[1, 1, 2], [0, 1, 2]])
会产生一个4D
向量?...是不是它是一个由二维数组作为元素的数组?
我想我需要阅读。我有太多的问题。有没有可能用二维数组来简化你的答案?......我认为,如果失败了,我将不得不回到第一原则......毫无疑问,你的答案可能是正确的
没关系。我试图为每件事选择不同的维度以避免误解。输入x
是(2, 3)
。编码结果是 4D,因为我们设置了N=4
(想想类的数量)。这就是结果为(2, 3, 4)
或(4, 3, 2)
的原因,具体取决于展示位置。
实际上,是的,您也可以看到与x = tf.constant([0, 1, 2])
的区别:结果是(3, 4)
或(4, 3)
哈哈。我要回去阅读第一原则。我绝对不明白。与您的回答无关,但我对这个主题缺乏深度。谢谢,不管怎样。将其标记为信任您对该主题的了解的答案...然后我将重新阅读:)【参考方案2】:
对我来说,轴转换为“你在哪里添加额外的数字来增加维度”。至少我是这样解释它并作为助记符的。
例如,您有 [1,2,3,0,2,1],它的形状为 (1,6)。这意味着它是一个一维数组。 one_hot 在原始数组的每个位置添加零并将位置转换为 1,为此,原始数组必须比原始数组多 1 个维度,并且轴告诉函数在哪里添加它,这个新维度将识别示例。
轴=1
您添加第二个维度并保留第一个维度。这将产生一个 (6,4) 数组。因此,对于结果数组,您使用第一个维度 (0) 来了解您看到的是哪个示例,并使用第二个维度 (1,新的) 来了解该类是否处于活动状态。 newArr[0][1]=1 表示示例 0,类 1,在这种情况下表示示例 0 属于类 1。 0 1 2 3 <- class
[[ 0. 1. 0. 0.] <- example 0
[ 0. 0. 1. 0.] <- example 1
[ 0. 0. 0. 1.] <- example 2
[ 1. 0. 0. 0.] <- example 3
[ 0. 0. 1. 0.] <- example 4
[ 0. 1. 0. 0.]] <- example 5
轴=0
您添加第一个维度并移动现有维度。这将产生一个 (4,6) 数组。因此,对于结果数组,您使用第一个维度(0,新维度)来了解该类是否处于活动状态,并使用第二个维度 (1) 来了解您看到的示例。 newArr[0][1]=0 表示 0 类,示例 1,在这种情况下表示示例 1 不属于 0 类。 0 1 2 3 4 5 <- example
[[ 0. 0. 0. 1. 0. 0.] <- class 0
[ 1. 0. 0. 0. 0. 1.] <- class 1
[ 0. 1. 0. 0. 1. 0.] <- class 2
[ 0. 0. 1. 0. 0. 0.]] <- class 3
【讨论】:
很好的解释,但输入数组不是 (1, 6) 而不是 (6, 1)? 我对你的回答不满意...对不起【参考方案3】:对我来说,我是这样理解的—— (注意documentation中提到的索引只是类标签的信息,它可以是标量或向量或矩阵) 如果您的索引只是一个标量,则不需要轴。 但是,如果它是一个向量,您可以选择特征和类的方向2` 这里单热向量的图像有一行作为深度(类)和一列作为相应的特征(标签),因此对于这种情况,轴的值为 0。 同样,如果您想要特征 x 深度,则轴的值为 -1。
同样,如果索引是矩阵,那么您可以选择以下方向
(batch 表示索引中的行)3
batch x features x depth if axis == -1
batch x depth x features if axis == 1
depth x batch x features if axis == 0
【讨论】:
以上是关于在TensorFlow中,函数'tf.one_hot'中的参数'axis'是啥的主要内容,如果未能解决你的问题,请参考以下文章
如何在 TensorFlow 中使用“group_by_window”函数
有没有办法知道在tensorflow中调用了哪个c++核心函数?
在 TensorFlow 的 SKFlow 模型训练中应用自定义成本函数