ndarray每行N个最大值
Posted
技术标签:
【中文标题】ndarray每行N个最大值【英文标题】:N largest values in each row of ndarray 【发布时间】:2015-07-31 17:03:59 【问题描述】:我有一个 ndarray,其中每一行都是一个单独的直方图。对于每一行,我希望找到前 N 个值。
我知道全局前 N 值 (A fast way to find the largest N elements in an numpy array) 的解决方案,但我不知道如何获取每行的前 N。
我可以遍历每一行并应用一维解决方案,但我不应该能够通过 numpy 广播来做到这一点吗?
【问题讨论】:
【参考方案1】:您可以像这样使用np.argsort
和axis = 1
的行 -
import numpy as np
# Find sorted indices for each row
sorted_row_idx = np.argsort(A, axis=1)[:,A.shape[1]-N::]
# Setup column indexing array
col_idx = np.arange(A.shape[0])[:,None]
# Use the column-row indices to get specific elements from input array.
# Please note that since the column indexing array isn't of the same shape
# as the sorted row indices, it will be broadcasted
out = A[col_idx,sorted_row_idx]
示例运行 -
In [417]: A
Out[417]:
array([[0, 3, 3, 2, 5],
[4, 2, 6, 3, 1],
[2, 1, 1, 8, 8],
[6, 6, 3, 2, 6]])
In [418]: N
Out[418]: 3
In [419]: sorted_row_idx = np.argsort(A, axis=1)[:,A.shape[1]-N::]
In [420]: sorted_row_idx
Out[420]:
array([[1, 2, 4],
[3, 0, 2],
[0, 3, 4],
[0, 1, 4]], dtype=int64)
In [421]: col_idx = np.arange(A.shape[0])[:,None]
In [422]: col_idx
Out[422]:
array([[0],
[1],
[2],
[3]])
In [423]: out = A[col_idx,sorted_row_idx]
In [424]: out
Out[424]:
array([[3, 3, 5],
[3, 4, 6],
[2, 8, 8],
[6, 6, 6]])
如果你想让元素按降序排列,你可以使用这个额外的步骤 -
In [425]: out[:,::-1]
Out[425]:
array([[5, 3, 3],
[6, 4, 3],
[8, 8, 2],
[6, 6, 6]])
【讨论】:
您能解释一下 A 的索引是如何获得最高值的吗? 我喜欢这个答案,因为我也可以得到前N行的索引。【参考方案2】:你可以像你链接的问题一样使用np.partition
:排序已经沿着最后一个轴:
In [2]: a = np.array([[ 5, 4, 3, 2, 1],
[10, 9, 8, 7, 6]])
In [3]: b = np.partition(a, -3) # top 3 values from each row
In [4]: b[:,-3:]
Out[4]:
array([[ 3, 4, 5],
[ 8, 9, 10]])
【讨论】:
以上是关于ndarray每行N个最大值的主要内容,如果未能解决你的问题,请参考以下文章