在扩展 numbers.Real 的对象上使用 numpy 函数

Posted

技术标签:

【中文标题】在扩展 numbers.Real 的对象上使用 numpy 函数【英文标题】:Use numpy functions on objects extending numbers.Real 【发布时间】:2022-01-16 01:15:04 【问题描述】:

PEP 3141 为不同类型的数字引入了抽象基类,以允许自定义实现。 我想从numbers.Real 派生一个类并计算它的正弦值。使用 pythons math-module,这工作正常。当我在 numpy 中尝试相同的操作时,出现错误。

from numbers import Real
import numpy as np
import math

class Mynum(Real):
    def __float__(self):
        return 0.0
    # Many other definitions

a = Mynum()

print("math:")
print(math.sin(a))
print("numpy:")
print(np.sin(a))

结果

math:
0.0
numpy:
AttributeError: 'Mynum' object has no attribute 'sin'

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
[...]
in <module>
    print(np.sin(a)) TypeError: loop of ufunc does not support argument 0 of type Mynum which has no callable sin method

似乎 numpy 试图调用其参数的 sin-method。对我来说,这很令人困惑,因为标准数据类型(如 float)也没有这样的方法,但 np.sin 可以处理它们。

是否只是对标准数据类型进行某种硬编码检查,不支持 PEP 3141?还是我在课堂上错过了什么?

因为实现所有必需的方法非常繁琐,所以这是我当前与math-module 一起使用的代码:

from numbers import Real
import numpy as np
import math

class Mynum(Real):
    def __init__(self):
        pass

    def __abs__(self):
        pass

    def __add__(self):
        pass

    def __ceil__(self):
        pass

    def __eq__(self):
        pass

    def __float__(self):
        return 0.0

    def __floor__(self):
        pass

    def __floordiv__(self):
        pass

    def __le__(self):
        pass

    def __lt__(self):
        pass

    def __mod__(self):
        pass

    def __mul__(self):
        pass

    def __neg__(self):
        pass

    def __pos__(self):
        pass

    def __pow__(self):
        pass

    def __radd__(self):
        pass

    def __rfloordiv__(self):
        pass

    def __rmod__(self):
        pass

    def __rmul__(self):
        pass

    def __round__(self):
        pass

    def __rpow__(self):
        pass

    def __rtruediv__(self):
        pass

    def __truediv__(self):
        pass

    def __trunc__(self):
        pass

a = Mynum()
print("math:")
print(math.sin(a))
print("numpy:")
print(np.sin(a))

【问题讨论】:

该 PEP 不适用于 numoy's dtypes。 【参考方案1】:

我刚刚回答了这样的问题,但我会重复一遍

np.sin(a)

其实是

np.sin(np.array(a))

np.array(a) 产生什么? dtype 是什么?

如果它是一个对象 dtype 数组,则说明该错误。使用 object dtype 数组,numpy 遍历(引用),并尝试在每个上运行适当的方法。对于可以使用类似__add__ 方法的操作员来说,这通常没问题,但几乎没有人定义sinexp 方法。

从昨天开始

How can I make my class more robust to operator/function overloading?

比较数字 dtype 数组和对象 dtype:

In [428]: np.sin(np.array([1,2,3]))
Out[428]: array([0.84147098, 0.90929743, 0.14112001])

In [429]: np.sin(np.array([1,2,3], object))
AttributeError: 'int' object has no attribute 'sin'

The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  File "<ipython-input-429-d6927b9a87c7>", line 1, in <module>
    np.sin(np.array([1,2,3], object))
TypeError: loop of ufunc does not support argument 0 of type int which has no callable sin method

【讨论】:

引用的问题真的很有帮助,尤其是以下描述如何创建自定义数组类型的文档条目:numpy.org/doc/stable/user/basics.dispatch.html#basics-dispatch。有关对我有用的更多详细信息,请参阅下面我自己的答案。【参考方案2】:

请参阅 hpaulj 的回答(以及 linked question)了解为什么这不起作用。

阅读文档后,我选择创建一个custom numpy array container 并添加我自己的numpy ufunc 支持。相关方法是

def __array_ufunc__(self, ufunc, method, *args, **kwargs):
    if method == "__call__":
        scalars = []
        for arg in args:
            # CAUTION: order matters here because Mynum is also a number
            if isinstance(arg, self.__class__):
                scalars.append(arg.value)
            elif isinstance(arg, Number):
                scalars.append(arg)
            else:
                return NotImplemented
        return self.__class__(ufunc(*scalars, **kwargs))
    return NotImplemented

我选择只支持我自己的数据类型和 ufunc 的 numbers.Number,这使得实现非常简单。 有关详细信息,请参阅docs。

为了扩展numbers.Real,我们还需要定义各种魔术方法(见PEP 3141)。 通过扩展np.lib.mixins.NDArrayOperatorsMixin(除了numbers.Real),我们可以免费获得大部分。 其余的需要手动实现。

您可以在下面看到我的完整代码,它适用于math-module 函数以及 numpys。

from numbers import Real, Number
import numpy as np
import math


class Mynum(np.lib.mixins.NDArrayOperatorsMixin, Real):
    def __init__(self, value):
        self.value = value

    def __repr__(self):
        return f"self.__class__.__name__(value=self.value)"

    def __array__(self, dtype=None):
        return np.array(self.value, dtype=dtype)

    def __array_ufunc__(self, ufunc, method, *args, **kwargs):
        if method == "__call__":
            scalars = []
            for arg in args:
                # CAUTION: order matters here because Mynum is also a number
                if isinstance(arg, self.__class__):
                    scalars.append(arg.value)
                elif isinstance(arg, Number):
                    scalars.append(arg)
                else:
                    return NotImplemented
            return self.__class__(ufunc(*scalars, **kwargs))
        return NotImplemented

    # Below methods are needed because we are extending numbers.Real
    # NDArrayOperatorsMixin takes care of the remaining magic functions

    def __float__(self, *args, **kwargs):
        return self.value.__float__(*args, **kwargs)

    def __ceil__(self, *args, **kwargs):
        return self.value.__ceil__(*args, **kwargs)

    def __floor__(self, *args, **kwargs):
        return self.value.__floor__(*args, **kwargs)

    def __round__(self, *args, **kwargs):
        return self.value.__round__(*args, **kwargs)

    def __trunc__(self, *args, **kwargs):
        return self.value.__trunc__(*args, **kwargs)


a = Mynum(0)

print("math:")
print(math.sin(a))
print("numpy:")
print(np.sin(a))

【讨论】:

以上是关于在扩展 numbers.Real 的对象上使用 numpy 函数的主要内容,如果未能解决你的问题,请参考以下文章

在扩展的 BasicDataSource 对象上使用 JNDI

在对象数组上使用扩展运算符来访问元素[重复]

在自定义键盘扩展中检测输入对象视图类型

多线程环境中的扩展 OVERLAPPED 对象池:在何处以及如何有效地使用锁定

PHP使用mysqli扩展库实现增删改查(面向对象版)

jsonp