如何在类的成员函数上使用 numba?

Posted

技术标签:

【中文标题】如何在类的成员函数上使用 numba?【英文标题】:How do I use numba on a member function of a class? 【发布时间】:2017-06-05 18:20:18 【问题描述】:

我正在使用 Numba 0.30.1 的稳定版本。

我可以这样做:

import numba as nb
@nb.jit("void(f8[:])",nopython=True)                             
def complicated(x):                                  
    for a in x:
        b = a**2.+a**3.

作为一个测试用例,加速是巨大的。但是如果我需要加速类中的函数,我不知道如何进行。

import numba as nb
def myClass(object):
    def __init__(self):
        self.k = 1
    #@nb.jit(???,nopython=True)                             
    def complicated(self,x):                                  
        for a in x:
            b = a**2.+a**3.+self.k

self 对象使用什么 numba 类型?我需要在一个类中拥有这个函数,因为它需要访问一个成员变量。

【问题讨论】:

jitclass 怎么样?鉴于self 的定义是object,我认为不可能避免“object-fallback”。 如果您在下一个循环中立即覆盖它,b = a**2.+a**3.+self.k 将实现什么? self.k 只是示意性地表明我需要调用成员变量,不能只在类外拥有函数 【参考方案1】:

您有多种选择:

使用jitclass (http://numba.pydata.org/numba-doc/0.30.1/user/jitclass.html) 来“麻木”整个事情。

或者把成员函数做成一个包装器,把成员变量通过:

import numba as nb

@nb.jit
def _complicated(x, k):
    for a in x:
        b = a**2.+a**3.+k

class myClass(object):
    def __init__(self):
        self.k = 1

    def complicated(self,x):                                  
        _complicated(x, self.k)

【讨论】:

我在jitclass 上看到过该页面,但我完全不清楚如何明确说明每个成员函数的数据类型。你能举个例子吗?包装函数方法变得不优雅,并且比我一开始就把东西放在一个类中。 我会避免向成员函数声明参数的类型,而只是 Numba 通过类型推断来处理它。在最近的记忆中,我不记得声明类型会带来更好的性能。 github.com/numba/numba/tree/master/examples 中有一些更复杂的 jitclass 示例——例如github.com/numba/numba/blob/master/examples/stack.py @dbrane 我知道类的方法在当前版本的 numba 中不能被 jitted。如果尝试这样做,则会出现非常清晰的错误消息“TypeError:尚不支持类成员:复杂”。此外,我发现 jitclass 的使用不会为您的示例提供任何加速。 我在我正在使用的类的成员函数中声明了辅助函数:def complicated(self, x): def calc(x, k): return some difficult calculation【参考方案2】:

我的情况非常相似,我找到了一种在类中使用 Numba-JITed 函数的方法。

诀窍是使用静态方法,因为在将对象实例添加到参数列表之前不会调用这种方法。无法访问self 的缺点是您不能使用在方法之外定义的变量。因此,您必须将它们从有权访问self 的调用方法传递给静态方法。就我而言,我不需要定义包装方法。我只需要将我想要 JIT 编译的方法拆分为两个方法。

在您的示例中,解决方案是:

from numba import jit

class MyClass:
    def __init__(self):
        self.k = 1

    def calculation(self):
        k = self.k
        return self.complicated([1,2,3],k)

    @staticmethod
    @jit(nopython=True)                             
    def complicated(x,k):                                  
        for a in x:
            b = a**2 .+ a**3 .+ k

【讨论】:

以上是关于如何在类的成员函数上使用 numba?的主要内容,如果未能解决你的问题,请参考以下文章

如何在类的成员函数中调用复制构造函数?

6——在类的外部定义成员函数

为啥我不能在类的成员函数中初始化 QThread?

无法在类的析构函数中删除指向数组的成员指针

在类的成员函数中调用delete this

在类成员函数后面加const