How to assign values to an array from inside the worker_function of multiprocessing.Pool.map?

Posted: 2021-01-20 06:52:50

Question:

Basically, what I want is to insert those 2s into ar so that ar is changed outside worker_function.

import numpy as np
import multiprocessing as mp
from functools import partial


def worker_function(i=None, ar=None):
    val = 2
    ar[i] = val
    print(ar)


def main():
    ar = np.zeros(5)
    func_part = partial(worker_function, ar=ar)
    mp.Pool(1).map(func_part, range(2))
    print(ar)


if __name__ == '__main__':
    main()

So far, the only thing I have managed is to change a copy of ar inside worker_function, but not the array outside the function:

[2. 0. 0. 0. 0.]
[0. 2. 0. 0. 0.]
[0. 0. 0. 0. 0.]


Answer 1:

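First, a note on the parameter order of worker_function: the examples below put ar first so that functools.partial can bind it positionally. Binding it by keyword, as in the question, also works.

As you have observed, each process gets its own copy of the array, so writes made in a worker never reach the parent. The simplest fix is to return the modified array: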

import numpy as np
import multiprocessing as mp
from functools import partial


def worker_function(ar, i):  # ar comes first so that partial() can bind it positionally
    val = 2
    ar[i] = val
    #print(ar)
    return ar # return modified array


def main():
    ar = np.zeros(5)
    func_part = partial(worker_function, ar)
    arrays = mp.Pool(2).map(func_part, range(2)) # pool size of 2, otherwise what is the point?
    for array in arrays:
        print(array)


if __name__ == '__main__':
    main()

Prints:

[2. 0. 0. 0. 0.]
[0. 2. 0. 0. 0.]

But now you are dealing with two separately modified arrays. You have to add extra logic to merge the results from these two arrays into one:

import numpy as np
import multiprocessing as mp
from functools import partial


def worker_function(ar, i):  # ar comes first so that partial() can bind it positionally
    val = 2
    ar[i] = val
    #print(ar)
    return ar # return modified array


def main():
    ar = np.zeros(5)
    func_part = partial(worker_function, ar)
    arrays = mp.Pool(2).map(func_part, range(2)) # pool size of 2, otherwise what is the point?
    for i in range(2):
        ar[i] = arrays[i][i]
    print(ar)


if __name__ == '__main__':
    main()

Prints:

[2. 2. 0. 0. 0.]

It makes more sense, though, for worker_function to return just a tuple giving the index of the modified element and its new value:

import numpy as np
import multiprocessing as mp
from functools import partial


def worker_function(ar, i):  # ar comes first so that partial() can bind it positionally
    return i, i + 3 # index, new value


def main():
    ar = np.zeros(5)
    func_part = partial(worker_function, ar)
    results = mp.Pool(2).map(func_part, range(2))
    for index, value in results:
        ar[index] = value
    print(ar)


if __name__ == '__main__':
    main()

Prints:

[3. 4. 0. 0. 0.]

Of course, if worker_function modified more than one value, it would return a tuple of tuples.
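As a rough sketch of that multi-value case: each task returns several (index, value) pairs and the parent applies them all. The rule for which slots a task writes, the values, and the helper name apply_updates are made up for illustration, and a list of tuples is used here in place of a tuple of tuples.

```python
import multiprocessing as mp

import numpy as np


def worker_function(i):
    # Hypothetical rule for illustration: task i owns slots 2*i and 2*i + 1.
    return [(2 * i, i + 3), (2 * i + 1, i + 30)]


def apply_updates(ar, results):
    # Runs in the parent: write every (index, value) pair into the real array.
    for pairs in results:
        for index, value in pairs:
            ar[index] = value
    return ar


def main():
    ar = np.zeros(5)
    with mp.Pool(2) as pool:
        results = pool.map(worker_function, range(2))
    print(apply_updates(ar, results))


if __name__ == '__main__':
    main()
```

Only the small (index, value) pairs travel back to the parent, so the array itself never has to be sent to or returned from the workers.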

Finally, if you really do need to pass an object to the child processes, there is another way that uses a pool initializer:

import numpy as np
import multiprocessing as mp


def pool_initializer(ar):
    global the_array

    the_array = ar


def worker_function(i):
    return i, the_array[i] ** 2 # index, value


def main():
    ar = np.array([1,2,3,4,5])
    with mp.Pool(5, pool_initializer, (ar,)) as pool:
        results = pool.map(worker_function, range(5))
    for index, value in results:
        ar[index] = value
    print(ar)


if __name__ == '__main__':
    main()

Prints:

[ 1  4  9 16 25]


Answer 2:

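For better performance you should use shared memory here, multiprocessing.Array, so that the array is not rebuilt and sent between processes again and again. With shared memory the array is the same object in every process, which is not the case in your example, where copies are sent. That is also why you do not see the changes in the parent.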

import multiprocessing as mp
import numpy as np


def worker_function(i):
    global arr
    val = 2
    arr[i] = val
    print(mp.current_process().name, arr[:])


def init_arr(arr):
    globals()['arr'] = arr


def main():
    # as long as we don't conditionally modify the same indices 
    # from multiple workers, we don't need the lock ...
    arr = mp.Array('i', np.zeros(5, dtype=int), lock=False)
    mp.Pool(2, initializer=init_arr, initargs=(arr,)).map(worker_function, range(5))
    print(mp.current_process().name, arr[:])


if __name__ == '__main__':
    main()

Output (which worker handles which index may differ between runs):

ForkPoolWorker-1 [2, 0, 0, 0, 0]
ForkPoolWorker-2 [2, 2, 0, 0, 0]
ForkPoolWorker-1 [2, 2, 2, 0, 0]
ForkPoolWorker-2 [2, 2, 2, 2, 0]
ForkPoolWorker-1 [2, 2, 2, 2, 2]
MainProcess [2, 2, 2, 2, 2]

Process finished with exit code 0
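If the rest of the code expects a NumPy array and not a ctypes array, the shared buffer can be wrapped in a NumPy view with np.frombuffer. The following is only a minimal sketch, assuming float64 data; the names init_worker and view are hypothetical. As in the code above, no lock is used, so it relies on each task writing to a different index.

```python
import multiprocessing as mp

import numpy as np


def init_worker(shared):
    # Runs once per worker: wrap the shared buffer in a NumPy view (no copy).
    global view
    view = np.frombuffer(shared, dtype=np.float64)


def worker_function(i):
    view[i] = 2  # writes straight into the shared buffer


def main():
    # 'd' is a C double; lock=False returns a raw ctypes array, zero-initialized.
    shared = mp.Array('d', 5, lock=False)
    with mp.Pool(2, initializer=init_worker, initargs=(shared,)) as pool:
        pool.map(worker_function, range(5))
    # The parent wraps the same buffer to read the results.
    print(np.frombuffer(shared, dtype=np.float64))


if __name__ == '__main__':
    main()
```

With lock=True the returned wrapper would need shared.get_obj() before it could be passed to np.frombuffer.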

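Comments:

Hey, I learned a lot from your answer. May I ask how I can add an extra argument to worker_function, now that there is no partial function anymore? I am trying to add a variable x=5 in main(), pass it to worker_function and print it. Where would I add x as a parameter? Adding it in init_arr does not seem to work.

@ArturMüllerRomanov You can still use functools.partial for a second argument x; I just did not use it because it was not necessary. If you have multiple arguments, you can also use Pool.starmap() instead of Pool.map() and pass the arguments bundled as tuples, with .starmap(worker_function, zip(itertools.repeat(x), range(5))).

Zipping the arguments for starmap seems more intuitive than using functools.partial. Thanks :-)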
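The starmap suggestion from the comments could look roughly like this when combined with the shared array from the answer. This is a sketch, not the answerer's code: the helper build_args and the choice to store x in the array are assumptions for illustration. Because the zip yields (x, i) tuples, worker_function takes x first.

```python
import itertools
import multiprocessing as mp


def init_arr(arr):
    # Runs once per worker: publish the shared array as a module-level global.
    globals()['arr'] = arr


def worker_function(x, i):
    # starmap unpacks each (x, i) tuple into these two parameters.
    arr[i] = x


def build_args(x, n):
    # Pair the constant x with every index: (x, 0), (x, 1), ...
    return list(zip(itertools.repeat(x), range(n)))


def main():
    x = 5
    arr = mp.Array('i', 5, lock=False)  # five C ints, zero-initialized
    with mp.Pool(2, initializer=init_arr, initargs=(arr,)) as pool:
        pool.starmap(worker_function, build_args(x, 5))
    print(arr[:])


if __name__ == '__main__':
    main()
```

The same effect is available with pool.map(functools.partial(worker_function, x), range(5)), since partial binds x as the first positional argument.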
