在 python 类中重载 [] 运算符以引用 numpy.array 数据成员

Posted

技术标签:

【中文标题】在 python 类中重载 [] 运算符以引用 numpy.array 数据成员【英文标题】:Overloading the [] operator in python class to refer to a numpy.array data member 【发布时间】:2016-02-26 04:30:47 【问题描述】:

我编写了一个数据容器类,它本质上包含一个 numpy ndarray 成员以及生成 time_series 掩码/横截面掩码的方法,在环形缓冲区模式下获取日期索引(行#),处理调整大小,记住数据可能是一个环形缓冲区,并对形状/尺寸等实施限制。

由于我的类实现,现在我必须通过显式引用 *.data 成员来访问此对象包装的数据。这很麻烦,我想在我的类中实现 [] 运算符,这样当在我的类的实例上调用时,它引用底层 ndarray 对象上的相同操作。我怎样才能做到这一点?

def MyArray(object):
    def __init__(self, shape, fill_value, dtype):
        self.shape = shape
        self.fill_value = fill_value
        self.dtype = dtype
        self.data = numpy.empty(shape, fill_value=fill_value, dtype=dtype)

    def reset(self, fill_value=None):
        self.data.fill(fill_value or self.fill_value)

    def resize(self, shape):
        if self.data.ndim != len(shape): raise Exception("dim error")
        if self.data.shape < shape: raise Exception("sizing down not permitted")
        # do resizing

现在,如果我想在其他地方使用这个容器,我必须这样使用它:

arr = MyArray(shape=(10000,20), fill_value=numpy.nan, dtype='float')
arr.data[::10] = numpy.NAN
msk = numpy.random.randn(10000,20)<.5
arr.data[~msk] = -1.

我每次使用时都需要显式引用 arr.data 太麻烦且容易出错(我在很多地方都忘记了 .data 后缀)。

我有什么办法可以添加一些运算符,以便 arr 上的切片和索引实际上在 arr.data 上运行隐式

【问题讨论】:

你必须实现__getitem__属性 这适用于在 LHS 和 RHS 上使用 arr 的情况吗? @Mindstorm 如果你分配给它,那就是__setitem__。只需将参数传递给self.data self.data.fill(fill_value or self.fill_value) - 如果有人想用零填充它怎么办? 感谢您指出这一点。我忽略了这一点! :| 【参考方案1】:

您需要实现__getitem____setitem__ 魔术函数。

魔术方法的完整概述可以找到here。

import numpy as np

class MyArray():
    def __init__(self):
        self.data = np.zeros(10)

    def __getitem__(self, key):
        return self.data[key]

    def __setitem__(self, key, value):
        self.data[key] = value

    def __repr__(self):
        return 'MyArray()'.format(self.data)


a = MyArray()

print(a[9])
print(a[1:5])
a[:] = np.arange(10)
print(a)

这会给你这个结果:

0.0
[ 0.  0.  0.  0.]
MyArray([ 0.  1.  2.  3.  4.  5.  6.  7.  8.  9.])

继承

如果你只想修改或添加 np.ndarray 的行为,你可以继承它。这比普通的 python 类要复杂一些,但是实现你的案例应该不是那么难:

import numpy as np


class MyArray(np.ndarray):

    def __new__(cls, shape, fill_value=0, dtype=float):
        data = np.full(shape, fill_value, dtype)
        obj = np.asarray(data).view(cls)
        obj.fill_value = fill_value
        return obj

    def reset(self, fill_value=None):
        if fill_value is not None:
            self.fill_value = fill_value

        self.fill(self.fill_value)

有关详细信息,请参阅here。

【讨论】:

优秀。但是有一个问题:我是否必须写下每种方法才能将 arr 上的 [] 调用完全委托给 arr.data? a+= b(iadd) 有一种方法,a*= b(imul) 有一种方法。难道没有比在我的包装类中枚举所有这些方法更简洁的方法了吗? 您可以从数组继承并实现或覆盖您需要的方法。 但是,这有点复杂,但在这里深入处理:docs.scipy.org/doc/numpy/user/basics.subclassing.html 我在答案中添加了继承解决方案。 当我尝试调整大小时,我将该类作为 ndarray 子类的实现遇到了问题:&gt;&gt;&gt; a.resize((10,2)) Traceback (most recent call last): File "&lt;stdin&gt;", line 1, in &lt;module&gt; ValueError: cannot resize an array references or is referenced by another array in this way. Use the resize function. &gt;&gt;&gt; a.view(np.ndarray).resize((10,2)) Traceback (most recent call last): File "&lt;stdin&gt;", line 1, in &lt;module&gt; ValueError: cannot resize this array: it does not own its data

以上是关于在 python 类中重载 [] 运算符以引用 numpy.array 数据成员的主要内容,如果未能解决你的问题,请参考以下文章

C++,如何在派生类中调用基类的重载提取运算符?

Python面向对象之运算符重载

在 Python 中使用重载加法运算符时出现内存错误

C++:-> 运算符重载:以不同方式处理 const / nonconst 访问

尝试在模板类中重载 / 运算符的 C++ 错误

python-重载