如何在 python joblib 中写入共享变量

Posted

技术标签:

【中文标题】如何在 python joblib 中写入共享变量【英文标题】:How to write to a shared variable in python joblib 【发布时间】:2018-03-21 08:07:01 【问题描述】:

以下代码并行化了一个 for 循环。

import networkx as nx;
import numpy as np;
from joblib import Parallel, delayed;
import multiprocessing;

def core_func(repeat_index, G, numpy_arrary_2D):
  for u in G.nodes():
    numpy_arrary_2D[repeat_index][u] = 2;
  return;

if __name__ == "__main__":
  G = nx.erdos_renyi_graph(100000,0.99);
  nRepeat = 5000;
  numpy_array = np.zeros([nRepeat,G.number_of_nodes()]);
  Parallel(n_jobs=4)(delayed(core_func)(repeat_index, G, numpy_array) for repeat_index in range(nRepeat));
  print(np.mean(numpy_array));

可以看出,要打印的预期值为 2。但是,当我在集群(多核、共享内存)上运行我的代码时,它返回 0.0。

我认为问题在于每个工作人员都创建了自己的numpy_array 对象副本,而在主函数中创建的副本没有更新。如何修改代码以更新numpy数组numpy_array

【问题讨论】:

那么,你决定好答案了吗? ;-) 【参考方案1】:

joblib默认使用进程的多处理池,正如its manual所说:

在底层,Parallel 对象创建了一个多处理池, 在多个进程中分叉 Python 解释器以执行每个进程 列表的项目。延迟函数是一个简单的技巧 能够通过函数调用创建元组(函数、参数、kwargs) 语法。

这意味着,每个进程都继承了数组的原始状态,但无论它在其中写入什么,都会在进程退出时丢失。只有函数结果被传递回调用(主)进程。但是你没有返回任何东西,所以返回了None

要使共享数组可修改,您有两种方法:使用线程和使用共享内存。


与进程不同,线程共享内存。所以你可以写入数组,每个作业都会看到这个变化。根据joblib手册,是这样完成的:

  Parallel(n_jobs=4, backend="threading")(delayed(core_func)(repeat_index, G, numpy_array) for repeat_index in range(nRepeat));

当你运行它时:

$ python r1.py 
2.0

但是,当您将复杂的内容写入数组时,请确保正确处理数据或数据片段周围的锁,否则您将遇到竞争条件(google it)。

还要仔细阅读 GIL,因为 Python 中的计算多线程是有限的(与 I/O 多线程不同)。


如果您仍然需要这些进程(例如因为 GIL),您可以将该数组放入共享内存中。

这是一个更复杂的话题,但joblib + numpy shared memory example 也显示在joblib 手册中。

【讨论】:

【参考方案2】:

正如 Sergey 在他的回答中所写,进程不共享状态和内存。这就是您看不到预期答案的原因。

线程共享状态和内存空间,因为它们在同一个进程下运行。如果您有许多 I/O 操作,这很有用。由于 GIL

,它不会为您提供更多处理能力(更多 CPU)

进程间通信的一种技术是使用管理器的代理对象。您创建一个管理器对象,它在进程之间同步资源。

Manager() 返回的管理器对象控制一个服务器进程,该服务器进程保存 Python 对象并允许其他进程使用代理来操作它们。

我还没有测试过这段代码(我没有你使用的所有模块),它可能需要对代码进行更多修改,但是使用 Manager 对象应该看起来像这样

if __name__ == "__main__":
    G = nx.erdos_renyi_graph(100000,0.99);
    nRepeat = 5000;

    manager = multiprocessing.Manager()
    numpys = manager.list(np.zeros([nRepeat, G.number_of_nodes()])

    Parallel(n_jobs=4)(delayed(core_func)(repeat_index, G, numpys, que) for repeat_index in range(nRepeat));
    print(np.mean(numpys));

【讨论】:

那里的数据结构在语义上是一个浮点列表列表(矩阵/表),但实际上是numpy.arrays 的numpy.float64 值的numpy.array 的一个实例。通过默认管理器同步这些自定义数据类型会遇到很多麻烦,默认管理器仅支持少数标量值、本机列表和字典。

以上是关于如何在 python joblib 中写入共享变量的主要内容,如果未能解决你的问题,请参考以下文章

函数能否知道它们是不是已经在 Python 中进行了多处理(joblib)

python 使用joblib在内存中存储和检索对象

Python多处理(joblib)参数传递的最佳方式

导入sklearn时Python出错..无法从'joblib.logger'导入名称'Logger'

Python,与 joblib 并行化:延迟多个参数

多个 OpenMP 线程读取(不写入)共享变量的性能成本?