获取二维 numpy ndarray 或 numpy 矩阵中前 N 个值的索引
Posted
技术标签:
【中文标题】获取二维 numpy ndarray 或 numpy 矩阵中前 N 个值的索引【英文标题】:Get indices of top N values in 2D numpy ndarray or numpy matrix 【发布时间】:2017-02-07 19:15:21 【问题描述】:我有一个 N 维向量数组。
data = np.array([[5, 6, 1], [2, 0, 8], [4, 9, 3]])
In [1]: data
Out[1]:
array([[5, 6, 1],
[2, 0, 8],
[4, 9, 3]])
我正在使用 sklearn 的 pairwise_distances
function 来计算距离值矩阵。请注意,此矩阵关于对角线对称。
dists = pairwise_distances(data)
In [2]: dists
Out[2]:
array([[ 0. , 9.69535971, 3.74165739],
[ 9.69535971, 0. , 10.48808848],
[ 3.74165739, 10.48808848, 0. ]])
我需要与此矩阵dists
中的前 N 个值相对应的索引,因为这些索引将对应于 data
中的成对索引,它们表示它们之间距离最大的向量。
我尝试使用np.argmax(np.max(distances, axis=1))
获取每行中最大值的索引,并尝试使用np.argmax(np.max(distances, axis=0))
获取每列中最大值的索引,但请注意:
In [3]: np.argmax(np.max(dists, axis=1))
Out[3]: 1
In [4]: np.argmax(np.max(dists, axis=0))
Out[4]: 1
和:
In [5]: dists[1, 1]
Out[5]: 0.0
因为矩阵关于对角线对称,并且因为 argmax 返回它找到的第一个具有最大值的索引,所以我最终得到与存储最大值的行和列匹配的对角线中的单元格,而不是最高值本身的行和列。
在这一点上,我确信我可以编写更多代码来找到我正在寻找的值,但肯定有一种更简单的方法来做我想做的事情。所以我有两个或多或少等价的问题:
如何找到矩阵中前 N 个值对应的索引、或、如何找到具有前 N 个成对距离的向量来自向量数组?
【问题讨论】:
【参考方案1】:我会解开,argsort,然后解开。我并不是说这是最好的方法,只是说这是我想到的第一种方法,在有人发布更明显的内容后,我可能会羞愧地删除它。 :-)
也就是说(任意选择前 2 个值):
In [73]: dists = sklearn.metrics.pairwise_distances(data)
In [74]: dists[np.tril_indices_from(dists, -1)] = 0
In [75]: dists
Out[75]:
array([[ 0. , 9.69535971, 3.74165739],
[ 0. , 0. , 10.48808848],
[ 0. , 0. , 0. ]])
In [76]: ii = np.unravel_index(np.argsort(dists.ravel())[-2:], dists.shape)
In [77]: ii
Out[77]: (array([0, 1]), array([1, 2]))
In [78]: dists[ii]
Out[78]: array([ 9.69535971, 10.48808848])
【讨论】:
这显然行不通,如果你有一个对称的二维数组。【参考方案2】:作为对 DSM 其他非常好的答案的轻微改进,如果 N 的最大顺序无关紧要,则使用 np.argpartition()
而不是使用 np.argsort()
会更有效。
使用索引i
对数组arr
进行分区会重新排列元素,使得索引i
处的元素是第i 个最大的元素,而左侧的元素更大,而右侧的元素更小。左右的分区不一定是排序的。这样做的好处是它以线性时间运行。
【讨论】:
以上是关于获取二维 numpy ndarray 或 numpy 矩阵中前 N 个值的索引的主要内容,如果未能解决你的问题,请参考以下文章
PyTorch - 获取 'TypeError: pic 应该是 PIL Image 或 ndarray。得到 <class 'numpy.ndarray'>' 错误
TypeError:获取参数数组的类型无效 numpy.ndarray,必须是字符串或张量。 (不能将 ndarray 转换为张量或操作。)