numpy 索引不保持形状

Posted

技术标签:

【中文标题】numpy 索引不保持形状【英文标题】:numpy indexing doesn't keep shape 【发布时间】:2021-04-04 19:16:09 【问题描述】:

跑步

import numpy as np

a1 = np.arange(1, 5)
a2 = np.arange(2, 6)
a = np.array([a1, a2])
a[a <= 3]

结果

array([1, 2, 3, 2, 3])

相反,我想得到

np.array([[1, 2, 3], [2, 3]])

我应该如何更新上面的代码?我尝试了不同的切片,taketake_along_axis 无济于事。

PS:请注意,以上被认为是一个参差不齐的序列

可见弃用警告: 从参差不齐的嵌套序列创建一个 ndarray(这是一个 列表或元组的列表或元组或具有不同长度的 ndarray 或 形状)已弃用。如果您打算这样做,则必须指定 创建 ndarray 时的 'dtype=object' np.array([[1, 2, 3], [2, 3]])

【问题讨论】:

为什么要创建一个形状不明确的ndarray?您想从a 获得什么信息? 您无法通过索引生成参差不齐的序列。你能做的最好的就是逐行迭代。 @Moosefeather 我想使用生成的 ndarray 来索引另一个 ndarray aka。 b[a]=1.0(这也可能不起作用...) @gliptak 也许你想要的是np.where arr[bool_array] 总是生成一维数组。如果不清楚,您需要重新阅读基本的 numpy 索引文档。 numpy.org/doc/stable/reference/…。并且试图得到一个参差不齐的数组是一个强有力的指标,表明整个数组,无循环,计算是不可能的。 【参考方案1】:

使用 python list 存储生成的 numpy 数组。正如您已经遇到的,numpy 数组不支持不规则形状。

>>> import numpy as np
>>> a1 = np.arange(1, 5)
>>> a2 = np.arange(2, 6)
>>> a = np.array([a1, a2])
>>> inds = a <= 3
>>> inds
array([[ True,  True,  True, False],
       [ True,  True, False, False]])
>>> l = [a[i][inds[i]] for i in range(a.shape[0])]
>>> l
[array([1, 2, 3]), array([2, 3])]

【讨论】:

我宁愿避免每行处理 PS 即使常规形状也不起作用 import numpy as np a1 = np.arange(1, 5) a2 = np.arange(1, 5) a = np .array([a1, a2]) a[a 正如@hpaulj 在 cmets 中指出的那样,通过布尔数组索引数组无论如何都只会选择元素(在布尔值的真实位置)以形成最终的一维数组。只要您有一个常规数组a,即所有a1a2 等具有相同的长度,行处理就可以工作。【参考方案2】:

您可以为此目的使用 numpy 的 masked arrays。

import numpy as np
import numpy.ma as ma

a1 = np.arange(1, 5)
a2 = np.arange(2, 6)
a = np.array([a1, a2])
a_masked = ma.masked_greater(a, 3)

masked_greater 函数返回一个新数组,其中所有大于 3 的值都被屏蔽掉。

请注意

numpy.ma 模块为 numpy 提供了几乎类似工作的替代品 支持带掩码的数据数组。

因此,如果您将屏蔽数组传递给 ma 模块的函数,它将仅对未屏蔽的数组元素进行操作。

【讨论】:

是的,我也尝试过屏蔽数组。我会检查它们是否用于“下游”处理【参考方案3】:

如果你想构造一个元素为列表的一维数组,运行:

a = np.array([[1, 2, 3], [2, 3]], dtype="object")
assert a.shape == (2,)

如果你想构造一个二维数组,那么a1a2 的长度必须相等,但在你的示例中并非如此。

【讨论】:

以上是关于numpy 索引不保持形状的主要内容,如果未能解决你的问题,请参考以下文章

Element Wise函数在两个不同形状的numpy数组中的条目

交错形状不匹配的 NumPy 数组

NumPy之 索引技巧

学习NumPy全套代码超详细基本操作数据类型数组运算复制和试图索引切片和迭代形状操作通用函数线性代数

Numpy 花式索引和赋值

python基础之numpy.reshape详解