在numpy数组中找到第n个最小的元素[重复]

Posted

技术标签:

【中文标题】在numpy数组中找到第n个最小的元素[重复]【英文标题】:Find nth smallest element in numpy array 【发布时间】:2018-09-01 15:52:39 【问题描述】:

我只需要找到一维 numpy.array 中最小的第 n 个元素。

例如:

a = np.array([90,10,30,40,80,70,20,50,60,0])

我想获得第五小的元素,所以我想要的输出是40

我目前的解决方案是这样的:

result = np.max(np.partition(a, 5)[:5])

但是,找到 5 个最小的元素,然后取出最大的一个,这对我来说似乎有点笨拙。有更好的方法吗?我是否缺少一个可以实现目标的功能?

有些问题的标题与此类似,但我没有看到任何回答我的问题的内容。

编辑:

我本来应该提到它,但性能对我来说非常重要;因此,heapq 解决方案虽然不错,但对我不起作用。

import numpy as np
import heapq

def find_nth_smallest_old_way(a, n):
    return np.max(np.partition(a, n)[:n])

# Solution suggested by Jaime and HYRY    
def find_nth_smallest_proper_way(a, n):
    return np.partition(a, n-1)[n-1]

def find_nth_smallest_heapq(a, n):
    return heapq.nsmallest(n, a)[-1]
#    
n_iterations = 10000

a = np.arange(1000)
np.random.shuffle(a)

t1 = timeit('find_nth_smallest_old_way(a, 100)', 'from __main__ import find_nth_smallest_old_way, a', number = n_iterations)
print 'time taken using partition old_way: '.format(t1)    
t2 = timeit('find_nth_smallest_proper_way(a, 100)', 'from __main__ import find_nth_smallest_proper_way, a', number = n_iterations)
print 'time taken using partition proper way: '.format(t2) 
t3 = timeit('find_nth_smallest_heapq(a, 100)', 'from __main__ import find_nth_smallest_heapq, a', number = n_iterations)  
print 'time taken using heapq : '.format(t3)

结果:

time taken using partition old_way: 0.255564928055
time taken using partition proper way: 0.129678010941
time taken using heapq : 7.81094002724

【问题讨论】:

另外,查看docs.python.org/2/library/heapq.html 可能会有所帮助 @C.B.上述问题与我的问题有很大不同;它要求最小值和最大值,它是二维矩阵 这是怎么复制的?标题听起来很相似,但问题本身却大不相同。有时不同的问题会得出相同的答案,但这里的答案也大不相同。这个问题的答案不可能是我问题的答案。 【参考方案1】:

除非我遗漏了什么,否则你想做的是:

>>> a = np.array([90,10,30,40,80,70,20,50,60,0])
>>> np.partition(a, 4)[4]
40

np.partition(a, k) 会将a 的第一个最小元素k 放在a[k] 中,较小的值放在a[:k] 中,较大的值放在a[k+1:] 中。唯一需要注意的是,由于 0 索引,第五个元素位于索引 4。

【讨论】:

是的,就是这样。我想错了。我知道有更好的解决方案! 应该是 np.partition(a, 4)[3] 好的,第 5 个元素。 发现k必须大于等于括号[]中的数字。否则会弹出错误的答案(我预计这是一个错误)。我留下这个评论是为了防止有人滥用它来得到错误的答案【参考方案2】:

你可以使用heapq.nsmallest:

>>> import numpy as np
>>> import heapq
>>> 
>>> a = np.array([90,10,30,40,80,70,20,50,60,0])
>>> heapq.nsmallest(5, a)[-1]
40

【讨论】:

不过,请检查您的表现。我最近遇到了一种情况,heapq.nsmallest 看起来很完美,但切片sorted 的速度大约快了 25%。我相信堆方法对于某些数据来说更快,但并非对所有数据都适用。我不知道 numpy 数组是否有什么特别之处会影响这种方式或其他方式。 @PeterDeGlopper 好吧,对于较小的数据集,排序方法可能更快,但对于较大的数据集,堆方法应该更快。您指的数据有多大? 不大 - 大约 100 个 3 元组整数的列表。所以可能远低于堆方法获胜的水平。 我在原始帖子中的解决方案是 O(n),因为 np.partitionnp.max 都是 O(n)。 我见过一些实例,其中实际的 heapify 加上 n 个 heappop 操作比使用 nsmallest 或切片 sorted 快得多。只是把它扔在那里。【参考方案3】:

你不需要打电话numpy.max():

def nsmall(a, n):
    return np.partition(a, n)[n]

【讨论】:

应该是 np.partition(a, n)[n-1]

以上是关于在numpy数组中找到第n个最小的元素[重复]的主要内容,如果未能解决你的问题,请参考以下文章

如何在两个排序数组的并集中找到第 k 个最小的元素?

numpy数组列表中N个最高元素的索引[重复]

从 n 个排序数组中找到第 k 个最小的数

如何找到 Numpy 数组的 M 个元素的 N 个最大乘积子数组?

我需要一个 numpy 数组中的 N 个最小(索引)值

在python中找到一个numpy数组中元素的位置[重复]