One-Hot Encode numpy 数组与 >2 暗淡
Posted
技术标签:
【中文标题】One-Hot Encode numpy 数组与 >2 暗淡【英文标题】:One-Hot Encode numpy array with >2 dims 【发布时间】:2020-12-29 12:03:50 【问题描述】:我有一个形状为 (192, 224, 192, 1)
的 numpy 数组。最后一个维度是我想要一个热编码的整数类。例如,如果我有 12 个类,我希望结果数组的 为(192, 224, 192, 12)
,最后一个维度全为零,但在对应于原始值的索引处为 1。
我可以通过许多for
循环天真地做到这一点,但想知道是否有更好的方法来做到这一点。
【问题讨论】:
【参考方案1】:如果您知道最大值,则可以在单个索引操作中执行此操作。给定一个数组a
和m = a.max() + 1
:
out = np.zeros(a.shape[:-1] + (m,), dtype=bool)
out[(*np.indices(a.shape[:-1], sparse=True), a[..., 0])] = True
如果你删除不必要的尾随维度会更容易:
a = np.squeeze(a)
out = np.zeros(a.shape + (m,), bool)
out[(*np.indices(a.shape, sparse=True), a)] = True
索引中的显式元组是进行星形扩展所必需的。
如果您想将其扩展到任意维度,您也可以这样做。下面将在axis
的压缩数组中插入一个新维度。这里axis
是新轴在最终数组中的位置,与saynp.stack
一致,但与list.insert
不一致:
def onehot(a, axis=-1, dtype=bool):
pos = axis if axis >= 0 else a.ndim + axis + 1
shape = list(a.shape)
shape.insert(pos, a.max() + 1)
out = np.zeros(shape, dtype)
ind = list(np.indices(a.shape, sparse=True))
ind.insert(pos, a)
out[tuple(ind)] = True
return out
如果你有一个单例维度要扩展,广义的解决方案可以找到第一个可用的单例维度:
def onehot2(a, axis=None, dtype=bool):
shape = np.array(a.shape)
if axis is None:
axis = (shape == 1).argmax()
if shape[axis] != 1:
raise ValueError(f'Dimension at axis is non-singleton')
shape[axis] = a.max() + 1
out = np.zeros(shape, dtype)
ind = list(np.indices(a.shape, sparse=True))
ind[axis] = a
out[tuple(ind)] = True
return out
要使用最后一个可用的单例,请将axis = (shape == 1).argmax()
替换为
axis = a.ndim - 1 - (shape[::-1] == 1).argmax()
以下是一些示例用法:
>>> np.random.seed(0x111)
>>> x = np.random.randint(5, size=(3, 2))
>>> x
array([[2, 3],
[3, 1],
[4, 0]])
>>> a = onehot(x, axis=-1, dtype=int)
>>> a.shape
(3, 2, 5)
>>> a
array([[[0, 0, 1, 0, 0], # 2
[0, 0, 0, 1, 0]], # 3
[[0, 0, 0, 1, 0], # 3
[0, 1, 0, 0, 0]], # 1
[[0, 0, 0, 0, 1], # 4
[1, 0, 0, 0, 0]]] # 0
>>> b = onehot(x, axis=-2, dtype=int)
>>> b.shape
(3, 5, 2)
>>> b
array([[[0, 0],
[0, 0],
[1, 0],
[0, 1],
[0, 0]],
[[0, 0],
[0, 1],
[0, 0],
[1, 0],
[0, 0]],
[[0, 1],
[0, 0],
[0, 0],
[0, 0],
[1, 0]]])
onehot2
要求您将要添加的维度标记为单例:
>>> np.random.seed(0x111)
>>> y = np.random.randint(5, size=(3, 1, 2, 1))
>>> y
array([[[[2],
[3]]],
[[[3],
[1]]],
[[[4],
[0]]]])
>>> c = onehot2(y, axis=-1, dtype=int)
>>> c.shape
(3, 1, 2, 5)
>>> c
array([[[[0, 0, 1, 0, 0],
[0, 0, 0, 1, 0]]],
[[[0, 0, 0, 1, 0],
[0, 1, 0, 0, 0]]],
[[[0, 0, 0, 0, 1],
[1, 0, 0, 0, 0]]]])
>>> d = onehot2(y, axis=-2, dtype=int)
ValueError: Dimension at -2 is non-singleton
>>> e = onehot2(y, dtype=int)
>>> e.shape
(3, 5, 2, 1)
>>> e.squeeze()
array([[[0, 0],
[0, 0],
[1, 0],
[0, 1],
[0, 0]],
[[0, 0],
[0, 1],
[0, 0],
[1, 0],
[0, 0]],
[[0, 1],
[0, 0],
[0, 0],
[0, 0],
[1, 0]]])
【讨论】:
看到np.indices
被使用很有趣,我需要获得更多关于花式索引的经验
@RichieV。我已恢复您的编辑。 onehot
中的索引是故意以这种方式完成的。它的目的是在问题中对a.squeeze
而不是a
进行操作。但你对这个错误是正确的:)
@RichieV。我添加了一些示例来说明如何使用这两个函数,以符合您的测试精神。
感谢代码。与其他一些答案相比,这非常棒,而且速度非常快。
@PDPDPDPD。 RichieV 的回答非常相似。如果速度很重要,我会将它与我的基准进行比较。分解和分解非常便宜,因为它们不会复制内存。【参考方案2】:
您可以创建一个新的 zeros 数组并使用高级索引填充它。
# sample array with 12 classes
np.random.seed(123)
a = np.random.randint(0, 12, (192, 224, 192, 1))
b = np.zeros((a.size, a.max() + 1))
# use advanced indexing to get one-hot encoding
b[np.arange(a.size), a.ravel()] = 1
# reshape to original form
b = b.reshape(a.shape[:-1] + (a.max() + 1,))
print(b.shape)
print(a[0, 0, 0])
print(b[0, 0, 0])
输出
(192, 224, 192, 12)
[2]
[0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
类似于this answer,但具有数组重塑功能。
【讨论】:
如果不reshape,总索引数组会变短 @RitchieV。我已经发布了一个答案并将其推广到任意维度 如果你有机会玩它,请告诉我。我是从手机发布的,所以不能保证代码没有错误 这个答案对我的问题很有效。我必须做的唯一更改是将a.max() + 1
更改为我拥有的课程数量。这个特殊的 ML 问题是分段,所以整个数组是我的标签,但不是每个类都在每个标签中表示,所以它必须是硬编码的。
@PDPDPDPD 考虑支持 Mad 的答案,它实际上性能更好并且包含一个通用函数,很高兴您修复了您的代码!【参考方案3】:
SciKit-learn 有一个编码器:
from sklearn.preprocessing import OneHotEncoder
# Data
values = np.array([1, 3, 2, 4, 1, 2, 1, 3, 5])
val_reshape = values.reshape(len(values), 1)
# One-hot encoding
oh = OneHotEncoder(sparse = False)
oh_arr = oh.fit_transform(val_reshape)
print(oh_arr)
output:
[[1. 0. 0. 0. 0.]
[0. 0. 1. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 0. 0. 1. 0.]
[1. 0. 0. 0. 0.]
[0. 1. 0. 0. 0.]
[1. 0. 0. 0. 0.]
[0. 0. 1. 0. 0.]
[0. 0. 0. 0. 1.]]
【讨论】:
以上是关于One-Hot Encode numpy 数组与 >2 暗淡的主要内容,如果未能解决你的问题,请参考以下文章
Numpy与PandasSklearn中one-hot快速编码方法
Numpy与PandasSklearn中one-hot快速编码方法
使用 NumPy 从 Python 中的位置向量中没有 for 循环的 One-Hot 编码?
python使用sklearn中的MultiLabelBinarizer函数将多标签的分类变量进行独热编码(One-Hot Encode Features With Multiple Labels)