子类化 numpy ndarray 时,如何正确修改 __getitem__?

Posted

技术标签:

【中文标题】子类化 numpy ndarray 时,如何正确修改 __getitem__?【英文标题】:When subclassing a numpy ndarray, how can I modify __getitem__ properly? 【发布时间】:2015-09-25 18:41:44 【问题描述】:

我正在尝试继承 numpy 的 ndarray。在我的子类MyClass 中,我添加了一个名为time 的字段作为主数据的并行数组。

我的目标如下:假设我创建了一个 MyClass 的实例,我们称之为mc。 我切片mc,例如mc[2:6],我希望生成的对象不仅包含正确切片的np数组,还包含相应切片的time数组。

这是我的尝试:

class MyClass(np.ndarray):
    def __new__(cls, data, time=None):
        obj = np.asarray(data).view(cls)
        obj.time = time
        return obj
    def __array_finalize__(self, obj):
        setattr(self, 'time', obj.time)
    def __getitem__(self, item):
        #print item #for testing
        ret = super(MyClass, self).__getitem__(item)
        ret.time = self.time.__getitem__(item)
        return ret

这不起作用。经过几个小时的折腾,我意识到这是因为当我调用mc[2:6] 时,__getitem__ 实际上被调用了多次。首先,当它被调用时,item 变量正如预期的那样是slice(2,6,None)。但是随后,包含super(MyClass, self)... 的行再次调用了相同的函数,大概是为了检索切片的各个元素。

问题是它为__getitem__ 提供了一组奇怪的参数,总是负数。在 mc[2:6] 的示例中,它又调用了 4 次方法,item 的值分别为 -4、-3、-2 和 -1。

如您所见,这使我无法正确调整 ret.time 变量,因为它会尝试多次修改它,通常使用无意义的索引。

我已尝试通过多种方式解决此问题,包括复制对象并改为编辑该副本、获取对象的各种视图以及许多其他 hack,但似乎没有一个可以克服__getitem__ 反复调用的问题与请求的切片不对齐的负索引。

非常感谢您对正在发生的事情的任何帮助或解释。

【问题讨论】:

__new__ 中没有self 编辑了...这不是问题...对象工作正常,__getitem___ 除外 【参考方案1】:

我解决问题的方法(尝试做一些非常相似的事情)如下:

class MyClass(np.ndarray):
    ...

    def __getitem__(self, item):
        #print item #for testing
        ret = super(MyClass, self).__getitem__(item)
        if not isinstance(self, MyClass):
            return ret

        ret.time = self.time.__getitem__(item)
        return ret

这样,如果__getitem__被多次调用,你只会在调用实例为MyClass的第一次调用时调整time方法。

【讨论】:

【参考方案2】:

我有一个类似的问题,我以numpy matrix 类为例解决了这个问题。正如您在__array_finalize__ 中创建数组之前所注意到的,__getitem__ 可以被多次调用。所以解决方案是将潜在的新索引存储在__getitem__,但设置在__array_finalize__

class MyClass(np.ndarray):
    def __new__(cls, data, time=None):
        obj = np.asarray(data).view(cls)
        obj.time = time
        return obj
    def __array_finalize__(self, obj):
        setattr(self, 'time', obj.time)
        try:
            self.time = self.time[obj._new_time_index]
        except:
            pass

    def __getitem__(self, item):
        try:
            if isinstance(item, (slice, int)):
                self._new_time_index = item
            else:
                self._new_time_index = item[0]
        except: 
            pass
        return super().__getitem__(item)

【讨论】:

【参考方案3】:

如果您想在切片上更新time,请尝试

if isinstance(item, slice):
    ret.time = self.time.__getitem__(item)

在您的 __getitem__ 方法中。

然后,您的time-adjusting 代码每次切片仅调用一次,并且在从数组中获取单个项目时永远不会执行。

【讨论】:

以上是关于子类化 numpy ndarray 时,如何正确修改 __getitem__?的主要内容,如果未能解决你的问题,请参考以下文章

如何子类化 CuPy 数组?

一些 Numpy 函数返回 ndarray 而不是我的子类

TypeError:参数“x”的类型不正确(预期为cupy.core.core.ndarray,得到了numpy.ndarray)

Python之numpy详细教程

000.复习大纲

numpy快速入门