用地图写入切片引用?
Posted
技术标签:
【中文标题】用地图写入切片引用?【英文标题】:Write into slice references with map? 【发布时间】:2021-09-27 00:52:41 【问题描述】:我正在尝试写入应该通过引用传递给函数的 Python 切片。
def mpfunc(r):
r[:]=1
R=np.zeros((2,4))
mpfunc(R[0])
mpfunc(R[1])
print(R)
此代码按预期工作。 R
现在包含 1
。
但是当我使用map()
时
def mpfunc(r):
r[:]=1
R=np.zeros((2,4))
map(mpfunc,R)
R
似乎R
的切片不再通过引用传递,我从文档中不清楚这一点。 R
现在仍然是 0
。
最终,目标是使用multiprocessin.Pool.map()
,不幸的是,由于同样的原因,它似乎失败了:
from multiprocessing import Pool
def mpfunc(r):
r[:]=1
R=np.zeros((2,4))
with Pool(2) as p:
p.map(mpfunc,R)
print(R)
为什么会这样,我该如何解决?
【问题讨论】:
【参考方案1】:map
(在 Python 3 中)是惰性的,您需要使用它来触发函数,考虑以下简单示例:
def update_dict(dct):
dct.update("x":1)
data = ["x":0,"x":0,"x":0]
mp = map(update_dict, data)
print(data)
lst = list(map(update_dict, data))
print(data)
输出
['x': 0, 'x': 0, 'x': 0]
['x': 1, 'x': 1, 'x': 1]
请记住,如果可能,您应该避免调用 map
以获得副作用,以避免其他人在处理这段代码时感到困惑。
【讨论】:
【参考方案2】:所以在非多处理的情况下你必须迭代map
函数返回的iterable,以确保指定的函数已经应用于所有传递的可迭代对象。但Pool.map
并非如此。
但是你有一个更大的问题。您现在将数组传递给位于不同地址空间中的进程,除非基本 numpy
数组存储在共享内存中,否则无法通过引用来完成。
在以下代码中,每个进程的全局变量R
将使用numpy
数组的共享内存实现进行初始化。现在map
函数将与需要更新的该数组的索引一起使用:
import multiprocessing as mp
import numpy as np
import ctypes
def to_numpy_array(shared_array, shape):
'''Create a numpy array backed by a shared memory Array.'''
arr = np.ctypeslib.as_array(shared_array)
return arr.reshape(shape)
def to_shared_array(arr, ctype):
shared_array = mp.Array(ctype, arr.size, lock=False)
temp = np.frombuffer(shared_array, dtype=arr.dtype)
temp[:] = arr.flatten(order='C')
return shared_array
def init_worker(shared_array, shape):
global R
R = to_numpy_array(shared_array, shape)
def mpfunc(idx):
R[idx, :] = 1
if __name__ == '__main__':
R = np.zeros((2,4))
shape = R.shape
shared_array = to_shared_array(R, ctypes.c_int64)
# you have to now use the shared array as the base
R = to_numpy_array(shared_array, shape)
with mp.Pool(2, initializer=init_worker, initargs=(shared_array, shape)) as p:
p.map(mpfunc, range(shape[0]))
print(R)
打印:
[[1 1 1 1]
[1 1 1 1]]
【讨论】:
好的,现在我明白了。虽然其他两个 cmets 对解释可迭代问题非常有帮助,但我将把这个标记为解决方案,因为它也解决了我的多处理问题。非常感谢大家!【参考方案3】:由于您只是调用了一个 map 函数,它只是为您创建了一个生成器对象,而实际上并没有完成它的调用。生成器是一种延迟或延迟执行的 Python 方式。 所以这是你可以做到的方法之一。
...: def mpfunc(r):
...: r[:]=1
...:
...: R=np.zeros((2,4))
...:
...: # mpfunc(R[0])
...: # mpfunc(R[1])
...: list(map(mpfunc, R))
...:
...: print(R)
只需通过创建列表或任何适合您的方法来使用地图功能。理想情况下,使用next()
函数一个一个地消耗它。
[[1. 1. 1. 1.]
[1. 1. 1. 1.]]
这同样适用于您的多进程 sn-p。
【讨论】:
以上是关于用地图写入切片引用?的主要内容,如果未能解决你的问题,请参考以下文章