[源码解析] 快手八卦 --- 机器学习分布式训练新思路

Posted 罗西的思考

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了[源码解析] 快手八卦 --- 机器学习分布式训练新思路相关的知识,希望对你有一定的参考价值。

[源码解析] 快手八卦 — 机器学习分布式训练新思路(1)

0x00 摘要

“Bagua“ 是快手和苏黎世理工(ETH Zürich)联合开发的分布式训练框架。其专门针对分布式的场景设计特定的优化算法,实现算法和系统层面的联合优化,力图极致化分布式训练的效率。其特点是:

  • 并行性能显著提高;

  • 对网络环境更鲁棒;

  • “一键式”使用;

  • 分布式通讯算法易拓展性;

  • 可用于工业级场景大规模使用;

  • 安全、故障易排查;

本文以:

为基础来分析学习。本文学习“bagua"总体设计思路和负载均衡数据加载器。

0x01 设计思路

以下摘录于快手官方帖子 快手八卦!突破 TensorFlow、PyTorch 并行瓶颈的开源分布式训练框架来了! 和 ETH PPT,按照自己理解有调整。

1.1 如何通信

在数据并行之中,从单机单卡的训练到多机多卡训练的核心,是每个卡把自己的计算结果进行累加和传播,所以一个关键点是两个worker之间如何进行通信。

这个过程好比每个人把自己知道的信息传递给他人,然后又从其他人那里获取信息,最后完成全局的信息同步。如果把计算单元之间的信息同步类比为人与人之间的信息同步,那么社会实践经验告诉我们,“八卦”可能是消息传递最高效的模式。“八卦”消息传播具有去中心化、异步通讯、信息压缩的特点,这与 Bagua 里面实现的通讯算法刚好一一呼应。

1.2 通信模式分类

针对通信模式,有如下分类。

1.2.1 系统架构

按照系统架构来区分,是参数服务器和Allreduce。

下图是参数服务器和Allreduce范式的图例。

  • 参数服务器架构中,模型可以被分割成分片(shard)并分布到多个节点(我们称这些节点为 “参数服务器”)。在训练阶段,worker定期从参数服务器获取模型,利用计算单元(如GPU)进行前向和后向传播,并将梯度推送给参数服务器,而参数服务器汇总梯度并更新参数。
  • Allreduce范式之中,所有worker都与他们的邻居合作进行模型/梯度交换。现有的系统通常采用环形拓扑结构进行两阶段的交流:首先,范式将模型/梯度划分为n个块(其中n为节点数),并使用不同起点和终点的n个环来聚合n个块;其次,位于不同节点的每个块的聚合结果会在环内进行广播。

1.2.2 同步角度

从通信同步角度看可以分为同步或是异步(Synchronous or Asynchronous):

  • 同步模式中,在每一次迭代过程中,所有工作节点都需要进行通信,并且下一步迭代必须等待当前迭代的通信完成才能开始。
  • 反之,异步式分布算法 则不需要等待时间:当某个节点完成计算后就可直接传递本地梯度,进行模型更新。

1.2.3 通信拓扑

从通信拓扑角度看可以分成中心化或是去中心化(Centralized or Decentralized):

  • 在中心化的通讯模式中,梯度或模型的同步过程需要所有的工作节点进行参与,因此,较高的网络延时往往会导致训练效率的降低。
  • 去中心化的通信模式往往可以有效的解决上述问题:在该模式下,工作节点可以被连接成特定的拓扑结构(例如环),在通信过程中,每一个工作节点只与和它相邻的节点进行通信。

1.2.4 压缩

从通信压缩与否角度看,有完整精度模式或信息压缩模式(Full-Precision or Low-Precision)两种:

  • 完整精度模式会使用与本地模型相同的 32 位浮点数(float32)进行传输。
  • 另一方面,在通讯存在瓶颈的情况下,基于大量已有研究通过量化 (quantization) 或稀疏化 (sparsification) 等方法压缩梯度,再用压缩后的梯度更新参数。在很多场景下,可以达到和完整精度相同的精度,同时提升通讯效率。

1.3 挑战

快手在实现之中,遇到了三个挑战:

  • 理论基础:通信模式需要有理论的支撑,需要严格在理论上证明通信是有效的,收敛的。
  • 系统设计:现有分布式学习系统都无法满足所有的新的通信模式,所以需要设计新的系统结构,才能利用这种算法带来的优势。
    • 参数服务器基本操作put/get,无法实现去中心化和误差补偿。
    • Allreduce是全局性的,无法实现去中心化或者异步模式。
  • 评测:需要在大规模真实场景下对各种算法进行评测。

1.4 Bagua 实现

1.4.1 分层

Bagua 具体分为三层:

  • 算法层:在逻辑层基础之上,实现了具体算法,比如某一个算法是去中心化,压缩,异步的。
  • 逻辑通信层:在物理通信层基础之上,实现了多种通信原语,比如去中心化,精度,同步等等,这些通信原语不是针对某一类算法特殊设计的,而对上层是统一的。
  • 物理通信层:在此层集成了一些常见通信库,从而提供了基本的send,receive操作。

1.4.2 通信算法选项

针对通信模式分类,Bagua 相应将通信过程抽象成了如下的算法选项:

  • 中心化或是去中心化(Centralized or Decentralized)。

  • 同步或是异步(Synchronous or Asynchronous)。

  • 完整精度模式或信息压缩模式(Full-Precision or Low-Precision)。

虽然为了提升通讯效率,Bagua 没有依照传统的方式同步所有计算节点的结果,甚至每次同步的信息还有偏差,但是得益于最新理论上的进展,这几种通讯策略以及他们的组合最终收敛解的正确性和效率仍然能得到充分保证,而且计算复杂度跟同步中心化和信息无损的方法相当,但是通讯效率更高。

Bagua 提供了一套详尽的通信模式来支持用户在上述模式中任意选择组合,我们将这一分布式训练系统对于上述算法选项的支持情况总结在下表中:

从表格中不难看出,现有框架的优化只是针对较为通用的算法(中心化同步完整精度),对于其他的算法组合,这些系统的支持非常有限。对于中心化同步进行信息压缩,这些系统往往只能支持较为简单的 float32->float16 压缩,相较而言,Bagua 则可以支持更为复杂的 ByteGrad,QAdam 等算法。对于其他的算法组合,现有的框架通常无法支持,而 Bagua 则可以自由支持。

1.4.3 总体

BAGUA的核心是一个训练算法,由开发者使用BAGUA提供的通信原语和抽象概念来实现。算法将最终用户提供的神经网络作为输入,并为其配备一个特定于算法的通信功能。具体来说,算法的开发者会在执行的不同阶段将这个通信功能注册为钩子。

1.4.4 优化

然而,简单地支持算法选项并不能直接在大规模集群上带来性能的提升。Bagua 的核心优势在于,为了追求极致化的性能,而实现算法和实现的联合优化。具体来讲,基于上述的通信层抽象,用户既可以方便得选择系统提供的各种算法组合从而获得性能提升,又能灵活得实现新的分布式 SGD 算法 —— Bagua 将自动为这一算法实现提供系统层优化。这些系统优化包含:

  • 将通讯时间隐藏在计算时间中。
  • 参数分桶及其内存管理。
  • 分层化的通信实现。

想要强调的是,这些系统实现层面的优化是对于各种算法组合广泛适用,而非局限在某一特定的算法设置上。因此,所有的系统优化都可以被灵活的复用到各种算法实现中去,这在保证“端到端”的性能提升的同时,也为开发新的分布式算法提供了良好的平台。

1.5 流程图

我们使用官方号的图例做一下总结

0x02 分析思路

通过官方文章我们可以发现对于分析学习来说有如下情况:

  • 通信方面的优化实现是八卦项目的一大特点。
  • 底层 Rust 语言笔者不熟悉。
  • 通盘研究整体代码不现实。

因此我们决定以 中心化、异步通讯,分层化的通信实现 为中心,再结合几个特色实现来学习分析。本文学习负载均衡数据加载器。

0x03 Load Balanced Data Loader

在某些场景下当训练数据中样本的计算复杂度是不同的,比如在 NLP 和语音任务中每个样本的长度就不同。这时,使用八卦的负载均衡数据加载器可以大大提高分布式训练吞吐量,在这种情况下,worker 的工作负载是相似的。我们接下来就从实例入手,看看如何实现数据加载的负载均衡

我们先看看负载均衡的需求,假如我们有两个模型副本进行数据并行,有如下数据,假如这些数据代表的是数据复杂度(会影响计算时间)

[ 7,  1, 11,  5,  10,  2,  9, 4,  6,  0,  8,  3]

那么第一个模型副本收到的数据为:[7,11,10,9,6, 8]。第二个模型副本收到的数据为:[1,5,2,4,0,3]。可以看出来两个模型在每个batch收到数据的复杂度不同,会造成负载不均衡。

                         +  8                         + 3
                         |                            |
                         |  6                         | 0
                         |                            |
                         |  9                         | 4
                         |                            |
batch 3   +----------->  |  10                        | 2  <----------+  batch 3
                         |                            |
batch 2   +----------->  |  11                        | 5  <----------+  batch 2
                         |                            |
batch 1   +----------->  v  7                         v 1  <----------+  batch 1

                  +-------------------+        +-------------------+
                  |                   |        |                   |
                  |     worker 0      |        |     worker 1      |
                  |                   |        |                   |
                  |                   |        |                   |
                  +-------------------+        +-------------------+

理想状态应该是两个模型每个batch收到的数据复杂度都相仿,比如第一个模型收到 [1,3,5,7,9],第二个模型的数据是[2,4,6,8,10],在下图的输入下,可以看到每次batch数据复杂度相仿,从而达到负载均衡的效果:

                         +                            +
                         |  9                         | 10
                         |                            |
                         |  7                         | 8
                         |                            |
batch 3   +----------->  |  5                         | 6  <----------+  batch 3
                         |                            |
batch 2   +----------->  |  3                         | 4  <----------+  batch 2
                         |                            |
batch 1   +----------->  v  1                         v 2  <----------+  batch 1

                  +-------------------+        +-------------------+
                  |                   |        |                   |
                  |     worker 0      |        |     worker 1      |
                  |                   |        |                   |
                  |                   |        |                   |
                  +-------------------+        +-------------------+

3.1 使用

我们直接使用源码中的例子修改学习一下。

import torch
from load_balancing_data_loader import LoadBalancingDistributedSampler
from torch.utils.data import TensorDataset, DataLoader

def test_load_balancing_distributed_batch_sampler():
    num_replicas = 2 # 分成两个副本
    total_batch = 3 

    n = sum([i + 1 for i in range(total_batch)]) * num_replicas
    dataset = TensorDataset(torch.randn(n, 2), torch.randperm(n))

    sampler = LoadBalancingDistributedSampler(
        dataset,
        complexity_fn=lambda x: x[1],
        num_replicas=num_replicas,
        rank=0,
        shuffle=True, # 需要shuffle
        random_level=0.5, # 加入随机
    )

    dataloader = torch.utils.data.DataLoader(dataset, sampler=sampler)

    cur_idx = 0
    for i, data in enumerate(dataloader):
        batch_size = data[0].shape[0]
        cur_idx += batch_size * num_replicas
        print(cur_idx)

test_load_balancing_distributed_batch_sampler()

因为此处代码十分绕,所以我们逐次解析。

3.2 生成数据集

首先是生成数据集部分。torch.randn(n, 2) 生成了随机张量,torch.randperm(n) 生成了 n 的随机排序。这里假定 n 是12。

# 生成了数据集
n = sum([i + 1 for i in range(total_batch)]) * num_replicas
dataset = TensorDataset(torch.randn(n, 2), torch.randperm(n))

TensorDataset 类似 zip 命令,生成了tuple列表。

dataset = TensorDataset: 12 
 tensors = tuple: 2 (
   
  0 = Tensor: 12 tensor([[-1.5556,  0.6848],\\n        [ 2.0811,  1.5011],\\n        [ 0.7434, -0.4990],\\n        [-0.2706,  1.7227],\\n        [ 0.2179,  0.0622],\\n        [-0.3014, -0.6435],\\n        [-0.1773, -1.3405],\\n        [-1.8212,  0.3702],\\n        [-0.5526, -0.2077],\\n        [-1.6543,  0.3109],\\n        [ 0.3265,  0.5987],\\n        [-1.5566,  0.2854]])
   
   1 = Tensor: 12 tensor([ 7,  8, 11,  4,  5,  2,  9, 10,  0,  6,  1,  3])

得出目前的TensorDataset如下 ,0 是实际数据,1 是数据复杂度,后续处理的目的就是按照数据复杂度对这些张量排序。我们可以设想下,最终排序应该就是一个复杂度均匀的排序结果。

+-----------------------------------------------------------------------------+
| TensorDataset                                                               |
|                                                                             |
|   0 = Tensor: 12 tensor([[-1.5556,  0.6848],......                        |
|                                                                             |
|   1 = Tensor: 12 tensor([ 7,  8, 11,  4,  5,  2,  9, 10,  0,  6,  1,  3]) |
|                                                                             |
+-----------------------------------------------------------------------------+

3.3 初始化

我们来到了 LoadBalancingDistributedSampler 的初始化。

def __init__(
    self,
    dataset: Dataset,
    complexity_fn: Callable[..., int],
    num_replicas: Optional[int] = None,
    rank: Optional[int] = None,
    shuffle: bool = True,
    seed: int = 0,
    drop_last: bool = False,
    random_level: float = 0,
) -> None:
    if num_replicas is None:
        num_replicas = dist.get_world_size()
    if rank is None:
        rank = dist.get_rank()

    self.dataset = dataset
    self.num_replicas = num_replicas
    self.rank = rank
    self.epoch = 0
    self.drop_last = drop_last

    # If the dataset length is evenly divisible by # of replicas, then there
    # is no need to drop any data, since the dataset will be split equally.
    dataset_len = len(self.dataset)  # type: ignore
    if self.drop_last and dataset_len % self.num_replicas != 0:  # type: ignore
        # Split to nearest available length that is evenly divisible.
        # This is to ensure each rank receives the same amount of data when
        # using this Sampler.
        self.num_samples = math.ceil(
            # `type:ignore` is required because Dataset cannot provide a default __len__
            # see NOTE in pytorch/torch/utils/data/sampler.py
            (dataset_len - self.num_replicas)
            / self.num_replicas
        )
    else:
        self.num_samples = math.ceil(dataset_len / self.num_replicas)  # type: ignore
    self.total_size = self.num_samples * self.num_replicas
    self.shuffle = shuffle
    self.seed = seed

""" 
此时变量为
self = LoadBalancingDistributedSampler: 6 
 dataset = TensorDataset: 12 <torch.utils.data.dataset.TensorDataset object at 0x7ff7385aecf8>
 drop_last = bool False
 epoch = int 0
 num_replicas = int 2
 num_samples = int 6
 rank = int 0
 seed = int 0
 shuffle = bool True
 total_size = int 12 
"""       
    
    # 以下是与PyTorch原生的主要不同之处
    self.item_complexity_map = dict()
    for item_index in range(dataset_len):
        # 每一个item都有一个complexity
        self.item_complexity_map[item_index] = complexity_fn(
            self.dataset[item_index]
        )

"""
complexity_fn 是选取 tuple 的第二个元素作为复杂度,我们回忆一下数据集的复杂度
Tensor: 12 tensor([ 7,  8, 11,  4,  5,  2,  9, 10,  0,  6,  1,  3])

所以得到了复杂度map如下:
item_complexity_map = dict: 12 0: tensor(7), 1: tensor(8), 2: tensor(11), 3: tensor(4), 4: tensor(5), 5: tensor(2), 6: tensor(9), 7: tensor(10), 8: tensor(0), 9: tensor(6), 10: tensor(1), 11: tensor(3)
 0 = Tensor tensor(7) # 第 0 个元素复杂度是 7
 1 = Tensor tensor(8) # 第 1 个元素复杂度是 8
 2 = Tensor tensor(11)
 3 = Tensor tensor(4)
 4 = Tensor tensor(5)
 5 = Tensor tensor(2)
 6 = Tensor tensor(9)
 7 = Tensor tensor(10)
 8 = Tensor tensor(0)
 9 = Tensor tensor(6)
 10 = Tensor tensor(1)
 11 = Tensor tensor(3)
"""        
        
    # 按照复杂度排序    
    self.ordered_item_complexity_map = OrderedDict(
        sorted(self.item_complexity_map.items(), key=lambda t: t[1])
    )
    
"""

以上是关于[源码解析] 快手八卦 --- 机器学习分布式训练新思路的主要内容,如果未能解决你的问题,请参考以下文章

spark 分布式训练原理解析

spark 分布式训练原理解析

[源码解析] 深度学习分布式训练框架 horovod --- 后台线程架构

[源码解析] PyTorch 分布式之弹性训练---Rendezvous 引擎

机器学习算法实现解析——word2vec源码解析

分布式机器学习第3章——分布式机器学习框架