在 numpy 中从具有索引的 2D 矩阵构建 3D 布尔矩阵

Posted

技术标签:

【中文标题】在 numpy 中从具有索引的 2D 矩阵构建 3D 布尔矩阵【英文标题】:Building a 3D boolean matrix from a 2D matrix with indices, in numpy 【发布时间】:2022-01-22 07:41:13 【问题描述】:

我有一个形状为 (3, 4) 的二维矩阵,其索引范围从 0 到 8:

a = array([[0, 4, 1, 2],
           [5, 0, 2, 3],
           [8, 6, 0, 5]])

目前,我使用for 循环来构建一个形状为(9, 3, 4) 的3D 布尔数组,该数组将True 存储在每个索引的位置,对于0 到8 之间的每一行:

b = np.zeros((9, 3, 4), dtype=bool)
for i in range(9):
    b[i] = np.where(a == i, True, False)

有没有办法在没有迭代的情况下达到相同的结果,也许使用 numpy 函数?

【问题讨论】:

【参考方案1】:

这是你要找的东西吗?

import numpy as np

a = np.array([[0, 4, 1, 2],
              [5, 0, 2, 3],
              [8, 6, 0, 5]])
y, x = np.mgrid[0:a.shape[0], 0:a.shape[1]]

data = np.zeros((9,) + a.shape, dtype=bool)
data[a, y, x] = True

【讨论】:

【参考方案2】:

利用 numpy 广播的非常简短的解决方案:

b = np.array([a]*9) == np.arange(9).reshape(-1,1,1)

输出:

>>> b
array([[[ True, False, False, False],
        [False,  True, False, False],
        [False, False,  True, False]],

       [[False, False,  True, False],
        [False, False, False, False],
        [False, False, False, False]],

       [[False, False, False,  True],
        [False, False,  True, False],
        [False, False, False, False]],

       [[False, False, False, False],
        [False, False, False,  True],
        [False, False, False, False]],

       [[False,  True, False, False],
        [False, False, False, False],
        [False, False, False, False]],

       [[False, False, False, False],
        [ True, False, False, False],
        [False, False, False,  True]],

       [[False, False, False, False],
        [False, False, False, False],
        [False,  True, False, False]],

       [[False, False, False, False],
        [False, False, False, False],
        [False, False, False, False]],

       [[False, False, False, False],
        [False, False, False, False],
        [False, False, False, False]]])

【讨论】:

以上是关于在 numpy 中从具有索引的 2D 矩阵构建 3D 布尔矩阵的主要内容,如果未能解决你的问题,请参考以下文章

如何在 Python 中从 Numpy 矩阵创建列表

将 2d 矩阵转换为 3d 单热矩阵 numpy

使用数组索引的numpy数组的2D索引[重复]

Python 2D NumPy 数组理解

使用索引同时从 numpy 2D 数组的行中减去多个值

在二维矩阵中使用 numpy.nanargmin()