序列化 Numpy 数组的意外行为
Posted
技术标签:
【中文标题】序列化 Numpy 数组的意外行为【英文标题】:Unexpected behavior serializing Numpy array 【发布时间】:2016-02-02 10:07:22 【问题描述】:代码
假设我有:
import numpy
import pickle
class Test():
def __init__(self):
self.base = numpy.zeros(6)
self.view = self.base[-3:]
def __len__(self):
return len(self.view)
def update(self):
self.view[0] += 1
def add(self):
self.view = self.base[-len(self.view) - 1:]
self.view[0] = 1
def __repr__(self):
return str(self.view)
def serialize_data():
data = Test()
return pickle.dumps(data)
请注意,Test
类只是一个包含 NumPy 数组 base
的 view
的类。这个view
只是基中最后一个N
元素的一部分(初始化时N == 3
)。
Test
有一个方法update()
将1
添加到视图位置0
的值,以及一个方法add()
修改视图大小(N = N + 1
)并将值设置为将0
定位到1
。
函数serialize_data
只是创建一个Test()
实例,然后使用pickle
返回序列化对象。
行为
如果我创建一个局部变量并 update
它两次和 add
它一次,一切都按预期工作:
# Local variable
test = Test()
print(test) # [ 0. 0. 0.]
test.update()
test.update()
print(test) # [ 2. 0. 0.]
test.add()
print(test) # [ 1. 2. 0. 0.]
现在,如果我从序列化数据中创建一个局部变量,那么在执行add
之后,值2
(在调用update
两次后设置)似乎丢失了:
# Serialized variable
data = pickle.loads(serialize_data())
print(data) # [ 0. 0. 0.]
data.update()
data.update()
print(data) # [ 2. 0. 0.]
data.add()
print(data) # [ 1. 0. 0. 0.] <---- This should be [ 1. 2. 0. 0. ] !!!
问题
为什么会发生这种情况,我该如何避免这种行为?
【问题讨论】:
问题在于,在酸洗/去酸洗之后,视图不再是基础视图,而是拥有自己的数据副本。 see here,不幸的是,没有关于如何防止这种情况的答案。 @kazemakase:有了这些信息,我可以解决我特定用例的问题。我将尝试实施它并用解决方案回答我自己的问题(以防将来对其他人有效)。谢谢! :-) PS:请考虑添加您的答案,以便我接受。 我认为我的评论有点微不足道,无法获得答案。但是,我找到了针对您的特定问题的解决方法,我会在一分钟内发布:) @kazemakase:我觉得已经够好了! :-D 使用__getstate__ and __setstate__
参考链接当然更好。但是,实际实现取决于用例(我发布的案例不是我正在使用的真实案例)。我将在接下来的 24 小时内接受您的答复。我喜欢将问题保留几个小时,以防其他人以不同的方法介入。 ;-) 再次感谢!
别担心接受。这是一个有趣的问题,我想知道是否有其他方法可以解决它。因此,如果您有不同的解决方案,请记住发布您自己的解决方案 :)
【参考方案1】:
问题在于,在酸洗/取消酸洗之后,视图不再是基础视图,而是拥有自己的数据副本。 See here,不幸的是,没有关于如何防止这种情况的答案。
可以通过为类定义 __getstate__
and __setstate__
方法来解决特定问题,这些方法在 unpickling 后重新定义视图。
除了视图之外,还需要跟踪视图所查看的基础部分。我选择使用切片对象,但还有其他方法。不需要对视图本身进行腌制,因为它会在取消腌制时从切片中重建。
class Test():
def __init__(self):
self.base = numpy.zeros(6)
self.slice = slice(-3, self.base.size)
self.view = self.base[self.slice]
def __len__(self):
return len(self.view)
def update(self):
self.view[0] += 1
def add(self):
self.slice = slice(-len(self.view) - 1, self.base.size)
self.view = self.base[self.slice]
self.view[0] = 1
def __getstate__(self):
return 'base': self.base, 'slice': self.slice
def __setstate__(self, state):
self.base = state['base']
self.slice = state['slice']
self.view = self.base[self.slice]
def __repr__(self):
return str(self.view)
【讨论】:
以上是关于序列化 Numpy 数组的意外行为的主要内容,如果未能解决你的问题,请参考以下文章
JSON - 使用 numpy 数组条目序列化 pandas 数据帧