喜马拉雅基于 HybridBackend 的深度学习模型训练优化实践

Posted

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了喜马拉雅基于 HybridBackend 的深度学习模型训练优化实践相关的知识,希望对你有一定的参考价值。

喜马拉雅作者:李超、陶云、许晨昱、胡文俊、张争光、赵云鹏、张玉静


喜马拉雅AI云借助阿里云提供的HybridBackend开源框架,实现了其推荐模型在 GPU 上的高效训练。

业务介绍

推荐场景是喜马拉雅app的重要应用之一,它广泛应用于热点、猜你喜欢、私人FM、首页信息流、发现页推荐、每日必听等模块。这些模块都依赖于喜马拉雅AI云,这是一套从数据、特征、模型到服务的全流程一站式算法工具平台。

喜马拉雅基于

推荐服务的一个核心诉求是能快速捕捉和反映用户不断变化的兴趣和当前热点,这就要求模型能在短时间内,以可控的成本完成对海量用户数据的训练。使用GPU等高性能硬件来加速模型训练已经成为CV, NLP等领域的行业标准;在使用稀疏训练数据的推荐场景下,国内外的各大厂商也在积极转向使用高性能GPU来替代传统的CPU训练集群,以提升训练的效率。

喜马拉雅基于

喜马拉雅AI云借助阿里云机器学习平台PAI的开源框架HybridBackend,实现了其推荐模型在 GPU 上的高效训练。在加速训练的同时, HybridBackend 框架高度易用,帮助其算法团队提升了开发效率。

问题与挑战

随着推荐业务的底层训练硬件逐渐从CPU向GPU转变,我们在生产实践中发现传统的训练方式存在严重的计算资源利用率不足的问题。经过调查与分析,我们发现计算资源利用不足主要来自于稀疏数据访问和分布式训练两方面:

  • 稀疏数据访问:我们使用经典机器学习中常用的 libsvm 数据格式来存储数据,将多个特征合并成一个稀疏字符串表达。在训练时,训练节点从远端的存储(如 OSS)上下载字符串,并从字符串中切分出多个特征输入,然后再喂入对应的 Embedding Table。在特征维度爆炸性增长的情况下,拼接字符串的数据量很大,导致数据读取严重受制于网络带宽;同时切分字符串也造成了 CPU 资源的消耗。
  • 分布式训练:我们尝试过多种分布式训练方式。起初,我们使用 keras+horovod 实现多GPU分布式训练,但在具体使用过程中发现有不少问题,比如出现加速不稳定、模型指标恶化等现象。后来,我们自研实现了一版基于参数服务器(PS)的分布式训练框架,通过内部的 xcache 服务实现 embedding 存储管理并进行线上同步,并使用自研的pspull和pspush算子进行embedding表的更新,一定程度上解决了分布式训练的效率问题。但在进一步增大训练数据量时发现,引入的 ps 算子因为频繁的 IO 交互成为了训练速度的瓶颈,降低了GPU设备利用率,同时 xcache 服务存储变长embedding 的支持成本很高,限制了算法工程师的优化空间。

HybridBackend

我们在调研如何解决上述问题和探索未来技术发展空间时发现了阿里云正在推广的开源框架 HybridBackend,该框架对稀疏模型训练过程中的数据访问、稀疏计算以及分布式训练都进行了深度优化(见图1),并提供了简单易用的接口。令人惊喜的是,这个框架兼容性很强,可以支持 TensorFlow、DeepRec 等多种训练框架,可以很好地满足我们服务不同业务客户的需求。此外,相关的架构和系统设计已经以论文形式在ICDE2022会议上公开,并且在Github上开源了主要功能,可以直接以pip方式安装。

喜马拉雅基于

图1.  HybridBackend的基本功能模块

图2 描述了在我司模型训练任务中落地 HybridBackend 的全景示意图。蓝色框代表了 HybridBackend 参与或加速了的流程部分。可以看到基本涵盖了全部模型训练流程,  下面重点介绍在数据读取和分布式训练上的优化成果。

图2. HybridBackend在喜马拉雅业务流程中的落地

稀疏数据访问优化

HybridBackend 提供了 hb.data.Dataset 接口 ,通过支持如 Parquet 这样的列存数据格式,可以极大加速稀疏数据的访问。如表1所示,HybridBackend 框架稀疏数据读取性能远高于其他实现。

文件格式

文件大小(MB)

框架

线程数

耗时(ms)

CSV

11062.61

TensorFlow

1

8858.38

Parquet (SNAPPY)

3346.10

TensorFlow I/O

1

103056.17

Parquet (SNAPPY)

3346.10

HybridBackend

1

397.88

Parquet (SNAPPY)

3346.10

HybridBackend

20

21.67

在我们的实际应用中,HybridBackend 稀疏数据访问功能中的一些功能效果显著:

  • 数据列选择性解析:我们将需要原有的类libsvm格式切换成宽表格式,其中每列对应一个特征。HybridBackend 可以支持在读取 Parquet 文件时只读取选择的字段,并将字段数据解析成 TensorFlow 所需要的格式,如自动将 list 类型的数据转换为 SparseTensor,或将 list 类型的数据进行填充截断后转换为 Tensor,满足了我们数据加载的多种需要。
  • 数据读取并行度设置:HybridBackend 可以通过设置num_parallel_reads 参数来调整读不同文件的并行度,通过设置num_parallel_parser_calls 参数来调整读文件中不同列的并行度。通过并行读取,在充分利用机器 CPU 资源的同时,加速了数据读取的性能。

在使用 HybridBackend 后,数据访问不再是我们的训练瓶颈。单卡训练的 GPU 平均利用率提升了 3x 以上,业务模型的训练周期显著缩短。

分布式训练优化

HybridBackend 提供了混合并行训练模式(如图3),每张 GPU 都会存储全部的稠密参数和部分的稀疏参数,并使用可以利用 NVLink 的 NCCL 通信协议来代替传统 PS 训练方式所使用的 RPC 协议。

喜马拉雅基于

图3. HybridBackend提供的混合并行训练模式

根据我们对未来一段时间内模型特征维度和大小的预估,以及我们对训练速度的需求,我们采用了 HybridBackend 混合并行方式进行训练,有效地提高了训练速度和 GPU 利用率。

我们还与 HybridBackend 社区的开发者协同工作,促进了 HybridBackend 对 Keras Model API 的支持,使我们能够在 Keras Model API 下利用 HybridBackend 进行混合并行,并实现模型热启等重要功能。这些功能极大地降低了使用成本。

总体收益

整体流程改造完毕之后,我们在推荐场景中,单机多卡训练 GPU 平均利用率提升了1.4x 以上(视具体模型不同),训练环节整体耗时减少50%以上。目前我们已经在使用了 Tensorflow 和 DeepRec 的模型中全量推广基于 HybridBackend 的训练方案。

喜马拉雅基于

未来规划

喜马拉雅 AI 云平台目前覆盖了喜马拉雅多个app的推荐、广告、搜索推荐等核心业务场景,以及画像产出、数据分析、BI数据生成等定制化开发场景。我们也在探索后续与 HybridBackend 社区的一些合作,以便更好地满足业务需求:

● 算子优化:HybridBackend支持了embedding lookup 过程中的各种算子的融合优化。我们会尝试通过这种方式提升模型在线推理的性能。

● PyTorch支持:NLP 搜推场景中有用 Pytorch 进行训练和部署的需求。我们需要HybridBackend能够支持该场景的实现。

● 超大型分布式训练:我们的模型训练级别达到了百亿样本十亿特征维度。随着算法复杂度的提升,我们需要支持更大的数据量和更高的维度的训练。

鸣谢

在合作共建阶段,我们得到了 HybridBackend 社区 陈浪石、袁满等的技术支持,他们技术高超、服务周到、响应及时。帮助我们快速完成了深度学习模型的训练流程优化,为我们的业务指标和算法优化空间带来了明显的提升。在此表示衷心的感谢!

HybridBackend 社区

欢迎在 GitHub 上 star 和提 issue,也可以直接在钉钉群中联系 HybridBackend 社区。

GitHub 地址:

​https://github.com/alibaba/HybridBackend​

喜马拉雅:基于 WeNet 和 gRPC 的语音识别微服务架构的设计和应用

近日,喜马拉雅语音团队在wenet中增加了基于gRPC的流式语音识别的支持。本文由喜马拉雅语音团队撰写,介绍wenet中的gRPC的设计和实现,并介绍喜马拉雅基于wenet和gRPC的语音识别微服务架构的设计和应用。

喜马拉雅科技有限公司是中国最大的有声读物平台,于2012年8月成立,2021年第一季度喜马拉雅全场景流量月活用户达到2.50亿。喜马拉雅AI语音组是喜马拉雅的核心部门,专注于语音合成、识别、语音信号处理、编解码以及智能音效的研究和开发,同时对接公司内外的多项业务和落地场景。

wenet介绍

wenet是由出门问问公司推出的一款开源ASR工具,自问世开始便积聚了大量人气。常规的ASR工具,或把模型训练过程中的复杂细节进行封装,为算法人员的模型训练过程提供便利;或把kaldi的速度与pytorch的便利通过脚本的方式结合了起来,加速了模型的训练过程。

wenet与以往工具不同之处在于,自问世起,它就同时提供了基于python/pytorch的训练脚本和基于c++/libtorch的工程化部署方案,是真正面向工业界的ASR工具。

喜马拉雅:基于 WeNet 和 gRPC 的语音识别微服务架构的设计和应用

websocket介绍

既然谈到工业界,就不得不谈到服务。业界大部分公司的流式ASR服务,都会支持websocket接口,这是因为websocket支持双向的流式数据传输,

wenet自身也提供了基于websocket的服务端/客户端示例。诚然,websocket具备轻量级、简单易用的优点,但在大型服务体系中,它的缺点也很明显。

使用过wenet的websocket客户端的朋友们应该清楚,流式识别过程分为三个步骤。wenet中的协议如下,首先,设置模式为发送text,发送形如{{"signal", "start"}{"nbest",1}}的json字符串标志开始,然后,设置模式为发送bytes,发送pcm数据,最后,设置模式为发送text,发送{{"signal", "end"}}的json字符串标志结束。websocket服务端的解析逻辑也是相当麻烦。首先需判断获取的数据是text还是bytes,再执行相应逻辑。若为text,还需尝试进行json解析,判断是否存在signal/nbest这些key,解析过程需加上大量的异常处理逻辑。

对接口设计比较敏感的朋友们肯定已经发现了websocket的弊端,即接口无硬性约束,全靠文档或口头对接,调用过程极易出错。肯定有朋友会说,哎呀我不想写这些乱七八糟的消息构造/解析代码,我就想关注服务调用的主要流程,有没有办法避免掉这些dirty work呢?有的,答案就是gRPC。

gRPC+protobuf介绍

gRPC是由google开发的一套RPC框架,基于http2.0,支持双向流式通信。gRPC的通信使用的是protobuf协议,接口定义更加清晰,还能减少数据传输量。gRPC+protobuf的好处不再赘述,对它们有深入兴趣的朋友们可以看一些详细介绍的博客。总结一下,google大法好。

有一点需要提醒的是,由于gRPC存在多路复用的概念,若朋友们基于k8s平台部署,常规的基于连接层的负载均衡器不会生效。但在我们看来,这恰恰是gRPC先进性的体现,说明它步子迈太大,兄弟们没跟上。大家或者可以使用nginx1.13.10以上版本进行请求转发,或者像喜马拉雅内部有一套consul服务发现/注册系统,客户端可以自行实现负载均衡策略。

wenet的proto设计

gRPC的接口定义描述在以proto为后缀的文件中。在gRPC中,每一个返回字段都有明确的类型定义。从而使得服务端的开发人员,仅通过proto文件,就可轻松得知如何调用gRPC服务。

针对wenet,我们设计了如下的proto文件,位于https://github.com/wenet-e2e/wenet/blob/main/runtime/core/gRPC/wenet.proto。分块介绍如下:

service ASR {
  rpc Recognize (stream Request) returns (stream Response) {}
}

以上部分说明,我们的服务名为ASR,实现了Recognize方法,输入输出皆为流式(由stream关键字标示)

message Request {
  message DecodeConfig {
    int32 nbest_config = 1;
    bool continuous_decoding_config = 2;
  }
  oneof RequestPayload {
    DecodeConfig decode_config = 1;
    bytes audio_data = 2;
  }
}

流式请求的Request是DecodeConfig/bytes中的一种(oneof关键字),其中DecodeConfig包含nbest_config/continuous_decoding_config两个字段,分别为int/bool类型

message Response {
  message OneBest {
    string sentence = 1;
    repeated OnePiece wordpieces = 2;
  }
  message OnePiece {
    string word = 1;
    int32 start = 2;
    int32 end = 3;
  }
  enum Status {
    ok = 0;
    failed = 1;
  }
  enum Type {
    server_ready = 0;
    partial_result = 1;
    final_result = 2;
    speech_end = 3;
  }
  Status status = 1;
  Type type = 2;
  repeated OneBest nbest = 3;
}

流式请求的Response含status/type/nbest三个字段,其中status/type为枚举类型,说明它们的赋值必须为范围内的一种。nbest为repeated OneBest类型,OneBest则由sentence和wordpieces字段组成。

大家可以看到,服务端/客户端无需存在任何hard-code的代码,所有字段都可以从Request/Response中以属性的方式获取。

wenet的gRPC实现

基于wenet的gRPC实现,其代码已经进行了merge,代码位于https://github.com/wenet-e2e/wenet/tree/main/runtime/core/{grpc,bin}

首先我们需编译proto文件,得到wenet.grpc.pb.h,wenet.grpc.pb.cc,wenet.pb.h,wenet.pb.cc四个文件。细心的朋友们肯定会发现,wenet.pb.h/cc中存储了protobuf数据格式的定义,wenet.grpc.pb.h中存储了gRPC服务端/客户端的定义。

gRPC服务端

gRPC服务端对纯虚基类ASR::Service进行继承并实现即可。

在Recognize方法中,我们做法与wenet自带的websocket完全相同,即每来一个gRPC请求,初始化一个GRPCConnectionHandler进行处理。通过ServerReaderWriter类型的stream对象,即可实现双向流式通信。

Status GrpcServer::Recognize(ServerContext* context,
                             ServerReaderWriter<Response, Request>* stream) {
  LOG(INFO) << "Get Recognize request" << std::endl;
  auto request = std::make_shared<Request>();
  auto response = std::make_shared<Response>();
  GrpcConnectionHandler handler(stream, request, response, feature_config_,
                                decode_config_, symbol_table_, model_, fst_);
  std::thread t(std::move(handler));
  t.join();
  return Status::OK;
}

gRPC客户端

客户端则需实例化ASR::Stub,通过ClientReaderWriter类型的stream对象,即可实现双向流式通信。

void GrpcClient::Connect() {
  channel_ = grpc::CreateChannel(host_ + ":" + std::to_string(port_),
                                 grpc::InsecureChannelCredentials());
  stub_ = ASR::NewStub(channel_);
  context_ = std::make_shared<ClientContext>();
  stream_ = stub_->Recognize(context_.get());
  request_ = std::make_shared<Request>();
  response_ = std::make_shared<Response>();
  request_->mutable_decode_config()->set_nbest_config(nbest_);
  request_->mutable_decode_config()->set_continuous_decoding_config(
      continuous_decoding_);
  stream_->Write(*request_);
}

喜马拉雅的流式ASR架构设计

通过以上介绍,有些动手能力强的朋友们可能会觉得,不管是gRPC,还是websocket服务,也就那么回事嘛,基于wenet搭建一个服务也并非难事。

喜马拉雅:基于 WeNet 和 gRPC 的语音识别微服务架构的设计和应用

其实我们也完全同意这样的看法,搭建一个流式服务并不难。但搭建一个流式服务架构,并在满足业务需求的同时,减轻后端开发/算法同学的工作量,则是一件讲究的事情。接下来,我们将介绍喜马拉雅语音团队及目前的流式微服务架构。

服务开发/算法训练过程中的问题

各位算法/开发同学,有没有碰到过以下情况:

1.我用Pytorch训了个模型,但可惜我们的服务是基于Kaldi的,上不了线

2.我开发了个预处理模块,用了一些Python的库,但我们的服务是Java/C++的,改写难度太大

3.我的服务出bug了,因为开发同学把我的算法逻辑写错了

4.业务有个新需求,要调用新的算法模块。开发同学只好加班加点修改线上服务,把新的算法代码添加到服务中

......

以上问题,如何解决?答案很简单,云原生(当然云原生也不是万能的,还是看服务架构是否适合云原生)

喜马拉雅:基于 WeNet 和 gRPC 的语音识别微服务架构的设计和应用

云原生的微服务如何解决该问题

在云原生的理念中,微服务/RPC是其中的精髓。

微服务解决了算法快速上线的问题,不管你是Python/Java/C++/Go,不管你是tensorflow/pytorch/kaldi,你只要将自己的算法模块以微服务的形式上线即可。RPC则提供了对外调用的方式。通过微服务的组合,我们可以对业务快速输出新的能力。

喜马拉雅的流式ASR架构

喜马拉雅的流式ASR目前初步分为4个微服务--业务接入服务/流式VAD服务/流式ASR服务/流式后处理服务。

喜马拉雅:基于 WeNet 和 gRPC 的语音识别微服务架构的设计和应用

 基于websocket的业务接入服务

业务接入服务,即所有业务的入口。对于业务方,我们提供最为通用的接口,即websocket调用方式(当然对于愿意配合的业务方,我们也可以提供原生的gRPC调用方式)。

每个业务在调用我们的流式服务之前,需在接入服务进行注册。注册的内容目前包含:采样率/比特数/通道数/是否定制模型/是否使用热词/何种后处理算法等等配置项。业务接入服务会根据不同业务的配置情况,自由组合后端的流式VAD/ASR/后处理服务,对业务进行输出。

所有的业务逻辑/预处理,如数据统一规整为16k/16bit,都会在业务接入服务进行处理。后端算法服务中不含任何业务逻辑。

喜马拉雅:基于 WeNet 和 gRPC 的语音识别微服务架构的设计和应用

基于gRPC的流式VAD/ASR/后处理服务

流式VAD/ASR/后处理服务,我们可以统称为算法服务。因为大部分算法人员对于C++/Java开发并不熟悉,后端算法服务统一基于gRPC进行封装,从而发挥了gRPC跨语言的优势。每个算法模块,其算法验证、服务上线、业务跟进由同一个人负责,极大的增加了算法人员的效率。再也没有算法人员与开发人员互相沟通,最后出现问题面面相觑的情景。

如英文ASR模型开发完毕,由该算法人员使用任意语言如python进行gRPC封装,使用公司的发布平台进行上线,最终向业务接入服务的负责人员提供一个接口即可。

采用微服务架构后的现状

喜马拉雅的流式微服务架构,即吸取了websocket的轻量级优势,又融合了gRPC的工程开发的便捷性,使得喜马拉雅的流式ASR服务可以快速相应业务需求。

对比常规的单体流式服务,从开发的角度,算法/开发人员的工作量极大减轻,只需使用自己最习惯的语言进行服务模块上线即可。从业务的角度,每当有新的业务需求,如业务A无需调用VAD/业务B需使用数字规整等,业务接入服务将后端算法能力进行组合即可,工作量很小。若后端能力已经具备,可以随时上线。

有些比较细心的同学们可能会提出疑问,微服务拆分后增加了模块间通信时间,会不会对实时率有较大影响?我们实践下来发现,微服务拆分带来的额外延时是很小的。一组相关的服务,通常会部署在同一个机房,甚至在k8s的同一个pod中。其通信时间以ms计。相比于给开发/算法人员带来的便利性而言,我们认为额外几十ms的延时是值得的。

往期精彩





We make the Net better
长按二维码关注


以上是关于喜马拉雅基于 HybridBackend 的深度学习模型训练优化实践的主要内容,如果未能解决你的问题,请参考以下文章

阿里开源自研工业级稀疏模型高性能训练框架 PAI-HybridBackend

喜马拉雅:基于 WeNet 和 gRPC 的语音识别微服务架构的设计和应用

深度学习之基于DCGAN实现动漫人物的生成

喜马拉雅基于DeepRec构建AI平台实践

深度学习之文本分类模型-基于CNNs系列

深度学习之文本分类模型-基于CNNs系列