torch.distributed.barrier() 是如何工作的

Posted

技术标签:

【中文标题】torch.distributed.barrier() 是如何工作的【英文标题】:How does torch.distributed.barrier() work 【发布时间】:2020-05-02 17:54:41 【问题描述】:

我已经阅读了所有我能找到的关于 torch.distributed.barrier() 的文档,但仍然无法理解它在 this script 中的使用方式,非常感谢一些帮助。

所以official doc of torch.distributed.barrier 说它“同步所有进程。如果 async_op 为 False,或者在 wait() 上调用异步工作句柄,这个集体会阻塞进程,直到整个组进入这个函数。”

在脚本中有两个地方用到:

First place

    if args.local_rank not in [-1, 0] and not evaluate:
        torch.distributed.barrier()  # Make sure only the first process in distributed training process the dataset, and the others will use the cache

        ... (preprocesses the data and save the preprocessed data)

    if args.local_rank == 0 and not evaluate:
        torch.distributed.barrier() 

Second place

    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab

        ... (loads the model and the vocabulary)

    if args.local_rank == 0:
        torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab

我无法将代码中的注释与官方文档中所述的此函数的功能联系起来。它如何确保只有第一个进程在两次调用 torch.distributed.barrier() 之间执行代码,为什么它只在第二次调用之前检查本地排名是否为 0?

提前致谢!

【问题讨论】:

【参考方案1】:

首先您需要了解排名。简而言之:在多处理上下文中,我们通常假设等级 0 是第一个进程或基本进程。然后对其他进程进行不同的排序,例如1、2、3,一共四个进程。

某些操作不需要并行完成,或者您只需要一个进程进行一些预处理或缓存,以便其他进程可以使用该数据。

在您的示例中,在非基本进程(等级 1、2、3)输入的第一个 if 语句中,它们将阻塞(或“等待”),因为它们遇到障碍。他们在那里等待,因为 barrier() 会阻塞,直到 所有 进程都到达障碍,但基础进程还没有到达障碍。

所以此时非基础进程(1、2、3)被阻塞,但基础进程(0)继续。基本进程将执行一些操作(在这种情况下预处理和缓存数据),直到它到达第二个 if 语句。在那里,基本进程将遇到障碍。至此,所有进程都已在屏障处停止,意味着可以解除屏障,所有进程都可以继续进行。因为基础进程准备了数据,所以其他进程现在可以使用该数据。

也许最重要的要理解的是:

当进程遇到障碍时,它将阻塞 屏障的位置并不重要(例如,并非所有进程都必须输入相同的 if 语句) 一个进程被屏障阻塞,直到所有进程都遇到屏障,然后所有进程的屏障都被解除

【讨论】:

感谢您的澄清:“一个进程被屏障阻塞,直到所有进程都遇到屏障,然后所有进程都解除屏障”是有道理的 @vgoklani 没问题。 感谢您的帮助解释。但我不清楚第二点。如果非基础进程在barrier等待,即使基础进程也不会进入if代码,当它通过if条件时,你的意思是所有进程都会继续吗? 很奇怪,基础进程不执行barrier(),如何通知其他进程所有进程都准备好,以便解除障碍? @QinshengZhang 进程不必进入同一个屏障,只需一个屏障。因此,如果您编写一个只有某些进程进入的 if 函数,那么只有在其他进程到达另一个障碍时它们才会继续。

以上是关于torch.distributed.barrier() 是如何工作的的主要内容,如果未能解决你的问题,请参考以下文章