子类化 numpy ndarray 时,如何正确修改 __getitem__?
Posted
技术标签:
【中文标题】子类化 numpy ndarray 时,如何正确修改 __getitem__?【英文标题】:When subclassing a numpy ndarray, how can I modify __getitem__ properly? 【发布时间】:2015-09-25 18:41:44 【问题描述】:我正在尝试继承 numpy 的 ndarray。在我的子类MyClass
中,我添加了一个名为time
的字段作为主数据的并行数组。
我的目标如下:假设我创建了一个 MyClass 的实例,我们称之为mc
。
我切片mc
,例如mc[2:6]
,我希望生成的对象不仅包含正确切片的np数组,还包含相应切片的time
数组。
这是我的尝试:
class MyClass(np.ndarray):
def __new__(cls, data, time=None):
obj = np.asarray(data).view(cls)
obj.time = time
return obj
def __array_finalize__(self, obj):
setattr(self, 'time', obj.time)
def __getitem__(self, item):
#print item #for testing
ret = super(MyClass, self).__getitem__(item)
ret.time = self.time.__getitem__(item)
return ret
这不起作用。经过几个小时的折腾,我意识到这是因为当我调用mc[2:6]
时,__getitem__
实际上被调用了多次。首先,当它被调用时,item
变量正如预期的那样是slice(2,6,None)
。但是随后,包含super(MyClass, self)...
的行再次调用了相同的函数,大概是为了检索切片的各个元素。
问题是它为__getitem__
提供了一组奇怪的参数,总是负数。在 mc[2:6]
的示例中,它又调用了 4 次方法,item
的值分别为 -4、-3、-2 和 -1。
如您所见,这使我无法正确调整 ret.time
变量,因为它会尝试多次修改它,通常使用无意义的索引。
我已尝试通过多种方式解决此问题,包括复制对象并改为编辑该副本、获取对象的各种视图以及许多其他 hack,但似乎没有一个可以克服__getitem__
反复调用的问题与请求的切片不对齐的负索引。
非常感谢您对正在发生的事情的任何帮助或解释。
【问题讨论】:
__new__
中没有self
。
编辑了...这不是问题...对象工作正常,__getitem___
除外
【参考方案1】:
我解决问题的方法(尝试做一些非常相似的事情)如下:
class MyClass(np.ndarray):
...
def __getitem__(self, item):
#print item #for testing
ret = super(MyClass, self).__getitem__(item)
if not isinstance(self, MyClass):
return ret
ret.time = self.time.__getitem__(item)
return ret
这样,如果__getitem__
被多次调用,你只会在调用实例为MyClass
的第一次调用时调整time
方法。
【讨论】:
【参考方案2】:我有一个类似的问题,我以numpy matrix 类为例解决了这个问题。正如您在__array_finalize__
中创建数组之前所注意到的,__getitem__
可以被多次调用。所以解决方案是将潜在的新索引存储在__getitem__
,但设置在__array_finalize__
。
class MyClass(np.ndarray):
def __new__(cls, data, time=None):
obj = np.asarray(data).view(cls)
obj.time = time
return obj
def __array_finalize__(self, obj):
setattr(self, 'time', obj.time)
try:
self.time = self.time[obj._new_time_index]
except:
pass
def __getitem__(self, item):
try:
if isinstance(item, (slice, int)):
self._new_time_index = item
else:
self._new_time_index = item[0]
except:
pass
return super().__getitem__(item)
【讨论】:
【参考方案3】:如果您想在切片上更新time
,请尝试
if isinstance(item, slice):
ret.time = self.time.__getitem__(item)
在您的 __getitem__
方法中。
然后,您的time
-adjusting 代码每次切片仅调用一次,并且在从数组中获取单个项目时永远不会执行。
【讨论】:
以上是关于子类化 numpy ndarray 时,如何正确修改 __getitem__?的主要内容,如果未能解决你的问题,请参考以下文章
TypeError:参数“x”的类型不正确(预期为cupy.core.core.ndarray,得到了numpy.ndarray)