在numpy中计算超过阈值的数组值的最快方法

Posted

技术标签:

【中文标题】在numpy中计算超过阈值的数组值的最快方法【英文标题】:Fastest way to count array values above a threshold in numpy 【发布时间】:2014-02-17 12:12:22 【问题描述】:

我有一个包含 10^8 个浮点数的 numpy 数组,我想计算其中有多少是 >= 给定阈值。速度至关重要,因为必须在大量此类阵列上进行操作。到目前为止的参赛者是

np.sum(myarray >= thresh)

np.size(np.where(np.reshape(myarray,-1) >= thresh))

Count all values in a matrix greater than a value 的答案表明 np.where() 会更快,但我发现时序结果不一致。我的意思是 some 实现和布尔条件 np.size(np.where(cond)) 比 np.sum(cond) 快,但对于某些人来说它更慢。

具体来说,如果大部分条目满足条件,则 np.sum(cond) 明显更快,但如果一小部分(可能小于十分之一)满足条件,则 np.size(np.where(cond)) 获胜.

问题分为两部分:

还有其他建议吗? np.size(np.where(cond)) 所花费的时间是否会随着 cond 为真的条目数增加而增加?

【问题讨论】:

numexpr 或 numba 可能会通过避免创建中间数组来加快速度。 还有 np.count_nonzero 在新的 numpy 版本上比布尔求和快得多。 【参考方案1】:

使用 cython 可能是一个不错的选择。

import numpy as np
cimport numpy as np
cimport cython
from cython.parallel import prange


DTYPE_f64 = np.float64
ctypedef np.float64_t DTYPE_f64_t


@cython.boundscheck(False)
@cython.wraparound(False)
@cython.nonecheck(False)
cdef int count_above_cython(DTYPE_f64_t [:] arr_view, DTYPE_f64_t thresh) nogil:

    cdef int length, i, total
    total = 0
    length = arr_view.shape[0]

    for i in prange(length):
        if arr_view[i] >= thresh:
            total += 1

    return total


@cython.boundscheck(False)
@cython.wraparound(False)
@cython.nonecheck(False)
def count_above(np.ndarray arr, DTYPE_f64_t thresh):

    cdef DTYPE_f64_t [:] arr_view = arr.ravel()
    cdef int total

    with nogil:
       total =  count_above_cython(arr_view, thresh)
    return total

不同建议方法的时间安排。

myarr = np.random.random((1000,1000))
thresh = 0.33

In [6]: %timeit count_above(myarr, thresh)
1000 loops, best of 3: 693 µs per loop

In [9]: %timeit np.count_nonzero(myarr >= thresh)
100 loops, best of 3: 4.45 ms per loop

In [11]: %timeit np.sum(myarr >= thresh)
100 loops, best of 3: 4.86 ms per loop

In [12]: %timeit np.size(np.where(np.reshape(myarr,-1) >= thresh))
10 loops, best of 3: 61.6 ms per loop

使用更大的数组:

In [13]: myarr = np.random.random(10**8)

In [14]: %timeit count_above(myarr, thresh)
10 loops, best of 3: 63.4 ms per loop

In [15]: %timeit np.count_nonzero(myarr >= thresh)
1 loops, best of 3: 473 ms per loop

In [16]: %timeit np.sum(myarr >= thresh)
1 loops, best of 3: 511 ms per loop

In [17]: %timeit np.size(np.where(np.reshape(myarr,-1) >= thresh))
1 loops, best of 3: 6.07 s per loop

【讨论】:

我猜这将取决于硬件,而在 cython 中,您可以更轻松地进行并行化。使用 cython 中的 -O3 (没有它的慢速)和开发 numpy,在我的计算机上它们的性能几乎相同(cython 略有优势,但 numpy 代码对于不连续的数组要快得多,尽管你当然可以修复它)。但是,您应该真正使用ssize_t/np.intp_tnot int,否则这是一个错误。

以上是关于在numpy中计算超过阈值的数组值的最快方法的主要内容,如果未能解决你的问题,请参考以下文章

在 python 或 spark 中获取大数据缺失值的最快方法是啥?

计算 numpy 数组和 csr_matrix 之间的成对最小值的最有效方法

左循环 numpy 数组的最快方法(如弹出、推送队列)

使用 NumPy 将 ubyte [0, 255] 数组转换为浮点数组 [-0.5, +0.5] 的最快方法

将交错的 NumPy 整数数组转换为 complex64 的最快方法是啥?

Spark 创建 numpy 数组 RDD 的最快方法