如何在 python joblib 中写入共享变量
Posted
技术标签:
【中文标题】如何在 python joblib 中写入共享变量【英文标题】:How to write to a shared variable in python joblib 【发布时间】:2018-03-21 08:07:01 【问题描述】:以下代码并行化了一个 for 循环。
import networkx as nx;
import numpy as np;
from joblib import Parallel, delayed;
import multiprocessing;
def core_func(repeat_index, G, numpy_arrary_2D):
for u in G.nodes():
numpy_arrary_2D[repeat_index][u] = 2;
return;
if __name__ == "__main__":
G = nx.erdos_renyi_graph(100000,0.99);
nRepeat = 5000;
numpy_array = np.zeros([nRepeat,G.number_of_nodes()]);
Parallel(n_jobs=4)(delayed(core_func)(repeat_index, G, numpy_array) for repeat_index in range(nRepeat));
print(np.mean(numpy_array));
可以看出,要打印的预期值为 2。但是,当我在集群(多核、共享内存)上运行我的代码时,它返回 0.0。
我认为问题在于每个工作人员都创建了自己的numpy_array
对象副本,而在主函数中创建的副本没有更新。如何修改代码以更新numpy数组numpy_array
?
【问题讨论】:
那么,你决定好答案了吗? ;-) 【参考方案1】:joblib
默认使用进程的多处理池,正如its manual所说:
在底层,Parallel 对象创建了一个多处理池, 在多个进程中分叉 Python 解释器以执行每个进程 列表的项目。延迟函数是一个简单的技巧 能够通过函数调用创建元组(函数、参数、kwargs) 语法。
这意味着,每个进程都继承了数组的原始状态,但无论它在其中写入什么,都会在进程退出时丢失。只有函数结果被传递回调用(主)进程。但是你没有返回任何东西,所以返回了None
。
要使共享数组可修改,您有两种方法:使用线程和使用共享内存。
与进程不同,线程共享内存。所以你可以写入数组,每个作业都会看到这个变化。根据joblib
手册,是这样完成的:
Parallel(n_jobs=4, backend="threading")(delayed(core_func)(repeat_index, G, numpy_array) for repeat_index in range(nRepeat));
当你运行它时:
$ python r1.py
2.0
但是,当您将复杂的内容写入数组时,请确保正确处理数据或数据片段周围的锁,否则您将遇到竞争条件(google it)。
还要仔细阅读 GIL,因为 Python 中的计算多线程是有限的(与 I/O 多线程不同)。
如果您仍然需要这些进程(例如因为 GIL),您可以将该数组放入共享内存中。
这是一个更复杂的话题,但joblib + numpy shared memory example 也显示在joblib
手册中。
【讨论】:
【参考方案2】:正如 Sergey 在他的回答中所写,进程不共享状态和内存。这就是您看不到预期答案的原因。
线程共享状态和内存空间,因为它们在同一个进程下运行。如果您有许多 I/O 操作,这很有用。由于 GIL
,它不会为您提供更多处理能力(更多 CPU)进程间通信的一种技术是使用管理器的代理对象。您创建一个管理器对象,它在进程之间同步资源。
Manager() 返回的管理器对象控制一个服务器进程,该服务器进程保存 Python 对象并允许其他进程使用代理来操作它们。
我还没有测试过这段代码(我没有你使用的所有模块),它可能需要对代码进行更多修改,但是使用 Manager 对象应该看起来像这样
if __name__ == "__main__":
G = nx.erdos_renyi_graph(100000,0.99);
nRepeat = 5000;
manager = multiprocessing.Manager()
numpys = manager.list(np.zeros([nRepeat, G.number_of_nodes()])
Parallel(n_jobs=4)(delayed(core_func)(repeat_index, G, numpys, que) for repeat_index in range(nRepeat));
print(np.mean(numpys));
【讨论】:
那里的数据结构在语义上是一个浮点列表列表(矩阵/表),但实际上是numpy.array
s 的numpy.float64
值的numpy.array
的一个实例。通过默认管理器同步这些自定义数据类型会遇到很多麻烦,默认管理器仅支持少数标量值、本机列表和字典。以上是关于如何在 python joblib 中写入共享变量的主要内容,如果未能解决你的问题,请参考以下文章
函数能否知道它们是不是已经在 Python 中进行了多处理(joblib)