在 python 类中重载 [] 运算符以引用 numpy.array 数据成员
Posted
技术标签:
【中文标题】在 python 类中重载 [] 运算符以引用 numpy.array 数据成员【英文标题】:Overloading the [] operator in python class to refer to a numpy.array data member 【发布时间】:2016-02-26 04:30:47 【问题描述】:我编写了一个数据容器类,它本质上包含一个 numpy ndarray 成员以及生成 time_series 掩码/横截面掩码的方法,在环形缓冲区模式下获取日期索引(行#),处理调整大小,记住数据可能是一个环形缓冲区,并对形状/尺寸等实施限制。
由于我的类实现,现在我必须通过显式引用 *.data 成员来访问此对象包装的数据。这很麻烦,我想在我的类中实现 [] 运算符,这样当在我的类的实例上调用时,它引用底层 ndarray 对象上的相同操作。我怎样才能做到这一点?
def MyArray(object):
def __init__(self, shape, fill_value, dtype):
self.shape = shape
self.fill_value = fill_value
self.dtype = dtype
self.data = numpy.empty(shape, fill_value=fill_value, dtype=dtype)
def reset(self, fill_value=None):
self.data.fill(fill_value or self.fill_value)
def resize(self, shape):
if self.data.ndim != len(shape): raise Exception("dim error")
if self.data.shape < shape: raise Exception("sizing down not permitted")
# do resizing
现在,如果我想在其他地方使用这个容器,我必须这样使用它:
arr = MyArray(shape=(10000,20), fill_value=numpy.nan, dtype='float')
arr.data[::10] = numpy.NAN
msk = numpy.random.randn(10000,20)<.5
arr.data[~msk] = -1.
我每次使用时都需要显式引用 arr.data 太麻烦且容易出错(我在很多地方都忘记了 .data 后缀)。
我有什么办法可以添加一些运算符,以便 arr
上的切片和索引实际上在 arr.data
上运行隐式?
【问题讨论】:
你必须实现__getitem__
属性
这适用于在 LHS 和 RHS 上使用 arr 的情况吗?
@Mindstorm 如果你分配给它,那就是__setitem__
。只需将参数传递给self.data
。
self.data.fill(fill_value or self.fill_value)
- 如果有人想用零填充它怎么办?
感谢您指出这一点。我忽略了这一点! :|
【参考方案1】:
您需要实现__getitem__
和__setitem__
魔术函数。
魔术方法的完整概述可以找到here。
import numpy as np
class MyArray():
def __init__(self):
self.data = np.zeros(10)
def __getitem__(self, key):
return self.data[key]
def __setitem__(self, key, value):
self.data[key] = value
def __repr__(self):
return 'MyArray()'.format(self.data)
a = MyArray()
print(a[9])
print(a[1:5])
a[:] = np.arange(10)
print(a)
这会给你这个结果:
0.0
[ 0. 0. 0. 0.]
MyArray([ 0. 1. 2. 3. 4. 5. 6. 7. 8. 9.])
继承
如果你只想修改或添加 np.ndarray 的行为,你可以继承它。这比普通的 python 类要复杂一些,但是实现你的案例应该不是那么难:
import numpy as np
class MyArray(np.ndarray):
def __new__(cls, shape, fill_value=0, dtype=float):
data = np.full(shape, fill_value, dtype)
obj = np.asarray(data).view(cls)
obj.fill_value = fill_value
return obj
def reset(self, fill_value=None):
if fill_value is not None:
self.fill_value = fill_value
self.fill(self.fill_value)
有关详细信息,请参阅here。
【讨论】:
优秀。但是有一个问题:我是否必须写下每种方法才能将 arr 上的 [] 调用完全委托给 arr.data? a+= b(iadd) 有一种方法,a*= b(imul) 有一种方法。难道没有比在我的包装类中枚举所有这些方法更简洁的方法了吗? 您可以从数组继承并实现或覆盖您需要的方法。 但是,这有点复杂,但在这里深入处理:docs.scipy.org/doc/numpy/user/basics.subclassing.html 我在答案中添加了继承解决方案。 当我尝试调整大小时,我将该类作为 ndarray 子类的实现遇到了问题:>>> a.resize((10,2)) Traceback (most recent call last): File "<stdin>", line 1, in <module> ValueError: cannot resize an array references or is referenced by another array in this way. Use the resize function. >>> a.view(np.ndarray).resize((10,2)) Traceback (most recent call last): File "<stdin>", line 1, in <module> ValueError: cannot resize this array: it does not own its data
以上是关于在 python 类中重载 [] 运算符以引用 numpy.array 数据成员的主要内容,如果未能解决你的问题,请参考以下文章