序列化 Numpy 数组的意外行为

Posted

技术标签:

【中文标题】序列化 Numpy 数组的意外行为【英文标题】:Unexpected behavior serializing Numpy array 【发布时间】:2016-02-02 10:07:22 【问题描述】:

代码

假设我有:

import numpy
import pickle


class Test():
    def __init__(self):
        self.base = numpy.zeros(6)
        self.view = self.base[-3:]

    def __len__(self):
        return len(self.view)

    def update(self):
        self.view[0] += 1

    def add(self):
        self.view = self.base[-len(self.view) - 1:]
        self.view[0] = 1

    def __repr__(self):
        return str(self.view)


def serialize_data():
    data = Test()
    return pickle.dumps(data)

请注意,Test 类只是一个包含 NumPy 数组 baseview 的类。这个view 只是基中最后一个N 元素的一部分(初始化时N == 3)。

Test 有一个方法update()1 添加到视图位置0 的值,以及一个方法add() 修改视图大小(N = N + 1)并将值设置为将0 定位到1

函数serialize_data 只是创建一个Test() 实例,然后使用pickle 返回序列化对象。

行为

如果我创建一个局部变量并 update 它两次和 add 它一次,一切都按预期工作:

# Local variable
test = Test()
print(test)    # [ 0.  0.  0.]

test.update()
test.update()
print(test)    # [ 2.  0.  0.]

test.add()
print(test)    # [ 1.  2.  0.  0.]

现在,如果我从序列化数据中创建一个局部变量,那么在执行add 之后,值2(在调用update 两次后设置)似乎丢失了:

# Serialized variable
data = pickle.loads(serialize_data())
print(data)    # [ 0.  0.  0.]

data.update()
data.update()
print(data)    # [ 2.  0.  0.]

data.add()
print(data)    # [ 1.  0.  0.  0.]  <----  This should be [ 1. 2. 0. 0. ] !!!

问题

为什么会发生这种情况,我该如何避免这种行为?

【问题讨论】:

问题在于,在酸洗/去酸洗之后,视图不再是基础视图,而是拥有自己的数据副本。 see here,不幸的是,没有关于如何防止这种情况的答案。 @kazemakase:有了这些信息,我可以解决我特定用例的问题。我将尝试实施它并用解决方案回答我自己的问题(以防将来对其他人有效)。谢谢! :-) PS:请考虑添加您的答案,以便我接受。 我认为我的评论有点微不足道,无法获得答案。但是,我找到了针对您的特定问题的解决方法,我会在一分钟内发布:) @kazemakase:我觉得已经够好了! :-D 使用__getstate__ and __setstate__ 参考链接当然更好。但是,实际实现取决于用例(我发布的案例不是我正在使用的真实案例)。我将在接下来的 24 小时内接受您的答复。我喜欢将问题保留几个小时,以防其他人以不同的方法介入。 ;-) 再次感谢! 别担心接受。这是一个有趣的问题,我想知道是否有其他方法可以解决它。因此,如果您有不同的解决方案,请记住发布您自己的解决方案 :) 【参考方案1】:

问题在于,在酸洗/取消酸洗之后,视图不再是基础视图,而是拥有自己的数据副本。 See here,不幸的是,没有关于如何防止这种情况的答案。

可以通过为类定义 __getstate__ and __setstate__ 方法来解决特定问题,这些方法在 unpickling 后重新定义视图。

除了视图之外,还需要跟踪视图所查看的基础部分。我选择使用切片对象,但还有其他方法。不需要对视图本身进行腌制,因为它会在取消腌制时从切片中重建。

class Test():
    def __init__(self):
        self.base = numpy.zeros(6)
        self.slice = slice(-3, self.base.size)
        self.view = self.base[self.slice]

    def __len__(self):
        return len(self.view)

    def update(self):
        self.view[0] += 1

    def add(self):
        self.slice = slice(-len(self.view) - 1, self.base.size)
        self.view = self.base[self.slice]        
        self.view[0] = 1

    def __getstate__(self):
        return 'base': self.base, 'slice': self.slice

    def __setstate__(self, state):
        self.base = state['base']
        self.slice = state['slice']
        self.view = self.base[self.slice]

    def __repr__(self):
        return str(self.view)

【讨论】:

以上是关于序列化 Numpy 数组的意外行为的主要内容,如果未能解决你的问题,请参考以下文章

意外的 protobuf-net 序列化程序行为

JSON - 使用 numpy 数组条目序列化 pandas 数据帧

使 numpy 数组 JSON 可序列化

序列化 numpy 数组列表并读回/反序列化为 Javascript

SimpleJSON 和 NumPy 数组

为包含可变长度序列的数组的输出标签创建分类 numpy 数组