Numba 代码比纯 python 慢

Posted

技术标签:

【中文标题】Numba 代码比纯 python 慢【英文标题】:Numba code slower than pure python 【发布时间】:2014-02-23 10:41:15 【问题描述】:

我一直致力于加快粒子过滤器的重采样计算。由于 python 有很多方法可以加快速度,所以我会尝试所有方法。不幸的是,numba 版本非常慢。由于 Numba 应该会加快速度,因此我认为这是我的错误。

我尝试了 4 个不同的版本:

    Numba Python 麻木 赛通

每个代码如下:

import numpy as np
import scipy as sp
import numba as nb
from cython_resample import cython_resample

@nb.autojit
def numba_resample(qs, xs, rands):
    n = qs.shape[0]
    lookup = np.cumsum(qs)
    results = np.empty(n)

    for j in range(n):
        for i in range(n):
            if rands[j] < lookup[i]:
                results[j] = xs[i]
                break
    return results

def python_resample(qs, xs, rands):
    n = qs.shape[0]
    lookup = np.cumsum(qs)
    results = np.empty(n)

    for j in range(n):
        for i in range(n):
            if rands[j] < lookup[i]:
                results[j] = xs[i]
                break
    return results

def numpy_resample(qs, xs, rands):
    results = np.empty_like(qs)
    lookup = sp.cumsum(qs)
    for j, key in enumerate(rands):
        i = sp.argmax(lookup>key)
        results[j] = xs[i]
    return results

#The following is the code for the cython module. It was compiled in a
#separate file, but is included here to aid in the question.
"""
import numpy as np
cimport numpy as np
cimport cython

DTYPE = np.float64

ctypedef np.float64_t DTYPE_t

@cython.boundscheck(False)
def cython_resample(np.ndarray[DTYPE_t, ndim=1] qs, 
             np.ndarray[DTYPE_t, ndim=1] xs, 
             np.ndarray[DTYPE_t, ndim=1] rands):
    if qs.shape[0] != xs.shape[0] or qs.shape[0] != rands.shape[0]:
        raise ValueError("Arrays must have same shape")
    assert qs.dtype == xs.dtype == rands.dtype == DTYPE

    cdef unsigned int n = qs.shape[0]
    cdef unsigned int i, j 
    cdef np.ndarray[DTYPE_t, ndim=1] lookup = np.cumsum(qs)
    cdef np.ndarray[DTYPE_t, ndim=1] results = np.zeros(n, dtype=DTYPE)

    for j in range(n):
        for i in range(n):
            if rands[j] < lookup[i]:
                results[j] = xs[i]
                break
    return results
"""

if __name__ == '__main__':
    n = 100
    xs = np.arange(n, dtype=np.float64)
    qs = np.array([1.0/n,]*n)
    rands = np.random.rand(n)

    print "Timing Numba Function:"
    %timeit numba_resample(qs, xs, rands)
    print "Timing Python Function:"
    %timeit python_resample(qs, xs, rands)
    print "Timing Numpy Function:"
    %timeit numpy_resample(qs, xs, rands)
    print "Timing Cython Function:"
    %timeit cython_resample(qs, xs, rands)

这会产生以下输出:

Timing Numba Function:
1 loops, best of 3: 8.23 ms per loop
Timing Python Function:
100 loops, best of 3: 2.48 ms per loop
Timing Numpy Function:
1000 loops, best of 3: 793 µs per loop
Timing Cython Function:
10000 loops, best of 3: 25 µs per loop

知道为什么 numba 代码这么慢吗?我认为它至少可以与 Numpy 相媲美。

注意:如果有人对如何加速 Numpy 或 Cython 代码示例有任何想法,那也很好:)我的主要问题是关于 Numba。

【问题讨论】:

我认为更好的地方是codereview.stackexchange.com 用更大的列表试试? @IanAuld:也许吧,但由于其他人已经从 numba 获得了显着的加速,我认为这是我用错了,而不仅仅是一个分析问题。在我看来,这符合 *** 的预期用途。 @JoranBeasley:我尝试了 1000 和 10000 分。 Numba 运行 1000 需要 773 毫秒,而纯 python 需要 234 毫秒。 10000点试炼还在进行中…… 请注意argmax 可以接受一个轴参数,因此您可以相互广播randslookup 以制作一个用于N^2 缩放算法的n x n 矩阵。或者,您可以使用 searchsorted 将具有(应该有?)Nlog(N) 缩放。 【参考方案1】:

问题是 numba 无法直觉lookup 的类型。如果您在方法中添加print nb.typeof(lookup),您会看到 numba 将其视为对象,这很慢。通常我只会在本地字典中定义lookup 的类型,但我遇到了一个奇怪的错误。相反,我只是创建了一个小包装器,以便我可以显式定义输入和输出类型。

@nb.jit(nb.f8[:](nb.f8[:]))
def numba_cumsum(x):
    return np.cumsum(x)

@nb.autojit
def numba_resample2(qs, xs, rands):
    n = qs.shape[0]
    #lookup = np.cumsum(qs)
    lookup = numba_cumsum(qs)
    results = np.empty(n)

    for j in range(n):
        for i in range(n):
            if rands[j] < lookup[i]:
                results[j] = xs[i]
                break
    return results

那么我的时间是:

print "Timing Numba Function:"
%timeit numba_resample(qs, xs, rands)

print "Timing Revised Numba Function:"
%timeit numba_resample2(qs, xs, rands)

Timing Numba Function:
100 loops, best of 3: 8.1 ms per loop
Timing Revised Numba Function:
100000 loops, best of 3: 15.3 µs per loop

如果你使用jit而不是autojit,你甚至可以更快一点:

@nb.jit(nb.f8[:](nb.f8[:], nb.f8[:], nb.f8[:]))

对我来说,它从 15.3 微秒降低到 12.5 微秒,但 autojit 的表现仍然令人印象深刻。

【讨论】:

是的,解决了!我尝试在 numba_cumsum 函数上展开循环,并对其进行 jit-ing,但它要么运行速度较慢,要么无法编译。看起来这是最快的速度。对我来说奇怪的是,numba 版本现在运行一致~是 cython 代码的两倍。由于它们都是编译的,我觉得这很奇怪。想法? @jammycrisp - 我还尝试手动编码 cumsum,我发现它比调用 numpy 稍微慢一些。至于 cython 和 numba 之间的差异,它可能与您使用的任何 c 编译器 vs llvm 有关。你用的是什么编译器?您是否在 setup.py 中指定了任何优化标志? 我使用的是 GCC 4.6.3。我不知道你可以在 setup.py 中添加编译器标志,但是在弄清楚之后我用 -O3 编译,它似乎没有改变任何东西。【参考方案2】:

更快的numpy 版本(与numpy_resample 相比加速10 倍)

def numpy_faster(qs, xs, rands):
    lookup = np.cumsum(qs)
    mm = lookup[None,:]>rands[:,None]
    I = np.argmax(mm,1)
    return xs[I]

【讨论】:

谢谢。我想有一种方法可以做到这一点,但在跳到 cython 之前并没有过多地研究它。对于 n=100,我只能从旧的 numpy 函数中获得 2 倍的加速,但很高兴知道。仍然很好奇为什么我的 numba 代码不起作用。

以上是关于Numba 代码比纯 python 慢的主要内容,如果未能解决你的问题,请参考以下文章

为啥这个 numba 代码比 numpy 代码慢 6 倍?

一行代码实现Python运行性能增强百倍,性能发动机numba模块介绍

一行代码实现Python运行性能增强百倍,性能发动机numba模块介绍

分配给数组时Numba慢吗?

为啥同时使用 numba.cuda 和 CuPy 从 GPU 传输数据这么慢?

让你python代码更快的3个小技巧