How to assign values to an array from inside the worker_function of multiprocessing.Pool.map?
Posted: 2021-01-20 06:52:50
Question:
Basically, what I want is to insert those 2s into ar, so that ar is changed outside worker_function.
import numpy as np
import multiprocessing as mp
from functools import partial

def worker_function(i=None, ar=None):
    val = 2
    ar[i] = val
    print(ar)

def main():
    ar = np.zeros(5)
    func_part = partial(worker_function, ar=ar)
    mp.Pool(1).map(func_part, range(2))
    print(ar)

if __name__ == '__main__':
    main()
So far, the only thing I have managed to do is change a copy of ar inside worker_function, not the ar outside the function:
[2. 0. 0. 0. 0.]
[0. 2. 0. 0. 0.]
[0. 0. 0. 0. 0.]
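The output above is consistent with how Pool.map hands work to a worker: the callable and its arguments are pickled and rebuilt on the other side. A minimal single-process sketch of that effect follows; the pickle round trip is only a stand-in for the transfer to a worker, and the name ar_in_worker is made up for illustration.

```python
import pickle

import numpy as np


def worker_function(i=None, ar=None):
    ar[i] = 2
    return ar


ar = np.zeros(5)

# Stand-in for the transfer to a worker process: serialising and
# deserialising the array produces an independent copy.
ar_in_worker = pickle.loads(pickle.dumps(ar))
worker_function(0, ar=ar_in_worker)

print(ar_in_worker)  # the copy carries the new value
print(ar)            # the original array is untouched
```

The two arrays do not share memory, so a write on the worker side cannot show up in the parent unless the result is sent back or the memory is explicitly shared.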
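Answer 1:
First, the arguments to worker_function are in the wrong order for binding ar positionally with partial, as the code below does (the keyword form ar=ar used in the question does run).
As you observed, each process gets its own copy of the array. The most you can do is return the modified array: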
import numpy as np
import multiprocessing as mp
from functools import partial

def worker_function(ar, i): # put the arguments in the correct order!
    val = 2
    ar[i] = val
    #print(ar)
    return ar # return modified array

def main():
    ar = np.zeros(5)
    func_part = partial(worker_function, ar)
    arrays = mp.Pool(2).map(func_part, range(2)) # pool size of 2, otherwise what is the point?
    for array in arrays:
        print(array)

if __name__ == '__main__':
    main()
Prints:
[2. 0. 0. 0. 0.]
[0. 2. 0. 0. 0.]
But now you are dealing with two separately modified arrays. You have to add extra logic to merge the results from these two arrays into one:
import numpy as np
import multiprocessing as mp
from functools import partial

def worker_function(ar, i): # put the arguments in the correct order!
    val = 2
    ar[i] = val
    #print(ar)
    return ar # return modified array

def main():
    ar = np.zeros(5)
    func_part = partial(worker_function, ar)
    arrays = mp.Pool(2).map(func_part, range(2)) # pool size of 2, otherwise what is the point?
    for i in range(2):
        ar[i] = arrays[i][i]
    print(ar)

if __name__ == '__main__':
    main()
Prints:
[2. 2. 0. 0. 0.]
But it makes more sense for worker_function to return just a tuple giving the index of the element to modify and its new value:
import numpy as np
import multiprocessing as mp
from functools import partial

def worker_function(ar, i): # put the arguments in the correct order!
    return i, i + 3 # index, new value

def main():
    ar = np.zeros(5)
    func_part = partial(worker_function, ar)
    results = mp.Pool(2).map(func_part, range(2))
    for index, value in results:
        ar[index] = value
    print(ar)

if __name__ == '__main__':
    main()
Prints:
[3. 4. 0. 0. 0.]
Of course, if worker_function modified more than one value, it would return a tuple of tuples.
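A minimal sketch of that multi-value case follows. The update rule (each task touching elements i and i + 2) and the helper apply_updates are made up for illustration; only the shape of the return value, a tuple of (index, value) pairs, comes from the answer.

```python
import multiprocessing as mp

import numpy as np


def worker_function(i):
    # Hypothetical rule: each task reports two updates, for i and i + 2.
    return ((i, i + 3), (i + 2, i + 10))


def apply_updates(ar, results):
    # results holds one tuple of (index, value) pairs per task.
    for updates in results:
        for index, value in updates:
            ar[index] = value
    return ar


def main():
    ar = np.zeros(5)
    with mp.Pool(2) as pool:
        results = pool.map(worker_function, range(2))
    print(apply_updates(ar, results))


if __name__ == '__main__':
    main()
```

Because the workers only return small tuples, the array itself never has to be sent to them, and all writes happen in the parent.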
Finally, if you really do need to pass an object to the child processes, there is another way, using a pool initializer:
import numpy as np
import multiprocessing as mp

def pool_initializer(ar):
    global the_array
    the_array = ar

def worker_function(i):
    return i, the_array[i] ** 2 # index, value

def main():
    ar = np.array([1,2,3,4,5])
    with mp.Pool(5, pool_initializer, (ar,)) as pool:
        results = pool.map(worker_function, range(5))
    for index, value in results:
        ar[index] = value
    print(ar)

if __name__ == '__main__':
    main()
Prints:
[ 1 4 9 16 25]
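Answer 2:
For performance, you should use shared memory here, multiprocessing.Array, to avoid reconstructing the array and sending it between processes again and again. The array is then the same object in all processes, which is not the case in your example, where copies are sent. That is also why you do not see the changes in the parent.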
import multiprocessing as mp
import numpy as np

def worker_function(i):
    global arr
    val = 2
    arr[i] = val
    print(mp.current_process().name, arr[:])

def init_arr(arr):
    globals()['arr'] = arr

def main():
    # as long as we don't conditionally modify the same indices
    # from multiple workers, we don't need the lock ...
    arr = mp.Array('i', np.zeros(5, dtype=int), lock=False)
    mp.Pool(2, initializer=init_arr, initargs=(arr,)).map(worker_function, range(5))
    print(mp.current_process().name, arr[:])

if __name__ == '__main__':
    main()
Output:
ForkPoolWorker-1 [2, 0, 0, 0, 0]
ForkPoolWorker-2 [2, 2, 0, 0, 0]
ForkPoolWorker-1 [2, 2, 2, 0, 0]
ForkPoolWorker-2 [2, 2, 2, 2, 0]
ForkPoolWorker-1 [2, 2, 2, 2, 2]
MainProcess [2, 2, 2, 2, 2]
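If the workers should operate on a NumPy array rather than a ctypes array, one possible variation is to wrap the shared buffer in a NumPy view inside the initializer. This is a sketch, not part of the answer: the names init_worker and shared_view are made up, and it assumes that type code 'd' (C double) matches np.float64, which holds on common platforms.

```python
import multiprocessing as mp

import numpy as np


def init_worker(raw):
    # Runs once per worker: view the shared buffer as a NumPy array (no copy).
    global shared_view
    shared_view = np.frombuffer(raw, dtype=np.float64)


def worker_function(i):
    shared_view[i] = 2  # writes directly into the shared buffer


def main():
    # Five C doubles in shared memory, zero-initialised, without a lock.
    raw = mp.RawArray('d', 5)
    with mp.Pool(2, initializer=init_worker, initargs=(raw,)) as pool:
        pool.map(worker_function, range(5))
    print(np.frombuffer(raw, dtype=np.float64))


if __name__ == '__main__':
    main()
```

As in the answer, skipping the lock is only reasonable while no two workers write to the same index.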
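Comments:
Hey, I learned a lot from your answer. May I ask how to add an extra argument to worker_function, now that there is no partial function any more? I am trying to add a variable x=5 in main(), pass it to worker_function, and print it. Where would I add x as a parameter? Adding it in init_arr does not seem to work.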
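@ArturMüllerRomanov You can still use functools.partial for a second argument x; I just did not use it because it was not necessary. But if you have multiple arguments, you can also use Pool.starmap() instead of Pool.map(), bundling the arguments as tuples and passing them with .starmap(worker_function, zip(itertools.repeat(x), range(5))).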
Zipping the arguments for starmap seems more intuitive than using functools.partial. Thanks :-)
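The starmap suggestion from the comments can be sketched as follows. The body of worker_function (returning i and x * i) is made up for illustration; the zip(itertools.repeat(x), range(5)) call is the one quoted in the comment.

```python
import itertools
import multiprocessing as mp


def worker_function(x, i):
    # x is the extra constant argument; i is the per-task index.
    return i, x * i


def main():
    x = 5
    # Pairs each index with the same x: (5, 0), (5, 1), (5, 2), ...
    args = zip(itertools.repeat(x), range(5))
    with mp.Pool(2) as pool:
        results = pool.starmap(worker_function, args)
    print(results)


if __name__ == '__main__':
    main()
```

The functools.partial alternative mentioned in the same comment would be pool.map(partial(worker_function, x), range(5)), which binds x as the first positional argument.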