获取二维 numpy ndarray 或 numpy 矩阵中前 N 个值的索引

Posted

技术标签:

【中文标题】获取二维 numpy ndarray 或 numpy 矩阵中前 N 个值的索引【英文标题】:Get indices of top N values in 2D numpy ndarray or numpy matrix 【发布时间】:2017-02-07 19:15:21 【问题描述】:

我有一个 N 维向量数组。

data = np.array([[5, 6, 1], [2, 0, 8], [4, 9, 3]])

In [1]: data
Out[1]:
array([[5, 6, 1],
       [2, 0, 8],
       [4, 9, 3]])

我正在使用 sklearn 的 pairwise_distances function 来计算距离值矩阵。请注意,此矩阵关于对角线对称。

dists = pairwise_distances(data)

In [2]: dists
Out[2]:
array([[  0.        ,   9.69535971,   3.74165739],
       [  9.69535971,   0.        ,  10.48808848],
       [  3.74165739,  10.48808848,   0.        ]])

我需要与此矩阵dists 中的前 N ​​个值相对应的索引,因为这些索引将对应于 data 中的成对索引,它们表示它们之间距离最大的向量。

我尝试使用np.argmax(np.max(distances, axis=1)) 获取每行中最大值的索引,并尝试使用np.argmax(np.max(distances, axis=0)) 获取每列中最大值的索引,但请注意:

In [3]: np.argmax(np.max(dists, axis=1))
Out[3]: 1

In [4]: np.argmax(np.max(dists, axis=0))
Out[4]: 1

和:

In [5]: dists[1, 1]
Out[5]: 0.0

因为矩阵关于对角线对称,并且因为 argmax 返回它找到的第一个具有最大值的索引,所以我最终得到与存储最大值的行和列匹配的对角线中的单元格,而不是最高值本身的行和列。

在这一点上,我确信我可以编写更多代码来找到我正在寻找的值,但肯定有一种更简单的方法来做我想做的事情。所以我有两个或多或少等价的问题:

如何找到矩阵中前 N 个值对应的索引如何找到具有前 N 个成对距离的向量来自向量数组?

【问题讨论】:

【参考方案1】:

我会解开,argsort,然后解开。我并不是说这是最好的方法,只是说这是我想到的第一种方法,在有人发布更明显的内容后,我可能会羞愧地删除它。 :-)

也就是说(任意选择前 2 个值):

In [73]: dists = sklearn.metrics.pairwise_distances(data)

In [74]: dists[np.tril_indices_from(dists, -1)] = 0

In [75]: dists
Out[75]: 
array([[  0.        ,   9.69535971,   3.74165739],
       [  0.        ,   0.        ,  10.48808848],
       [  0.        ,   0.        ,   0.        ]])

In [76]: ii = np.unravel_index(np.argsort(dists.ravel())[-2:], dists.shape)

In [77]: ii
Out[77]: (array([0, 1]), array([1, 2]))

In [78]: dists[ii]
Out[78]: array([  9.69535971,  10.48808848])

【讨论】:

这显然行不通,如果你有一个对称的二维数组。【参考方案2】:

作为对 DSM 其他非常好的答案的轻微改进,如果 N 的最大顺序无关紧要,则使用 np.argpartition() 而不是使用 np.argsort() 会更有效。

使用索引i 对数组arr 进行分区会重新排列元素,使得索引i 处的元素是第i 个最大的元素,而左侧的元素更大,而右侧的元素更小。左右的分区不一定是排序的。这样做的好处是它以线性时间运行。

【讨论】:

以上是关于获取二维 numpy ndarray 或 numpy 矩阵中前 N 个值的索引的主要内容,如果未能解决你的问题,请参考以下文章

PyTorch - 获取 'TypeError: pic 应该是 PIL Image 或 ndarray。得到 <class 'numpy.ndarray'>' 错误

TypeError:获取参数数组的类型无效 numpy.ndarray,必须是字符串或张量。 (不能将 ndarray 转换为张量或操作。)

创建没有固定第二维的3D numpy.ndarray

Numpy库:NumPy的数组类ndarray

numpy中的ndarray方法和属性

Python机器学习(五十二)SciPy 基础功能