pytorch 可以优化顺序操作(如张量流图或 JAX 的 jit)吗?

Posted

技术标签:

【中文标题】pytorch 可以优化顺序操作(如张量流图或 JAX 的 jit)吗?【英文标题】:Can pytorch optimize sequential operations (like a tensorflow graph or JAX's jit)? 【发布时间】:2020-02-24 00:55:57 【问题描述】:

原来tensorflow和pytorch有根本的区别:

tensorflow 基于计算图。构建此图并在会话中对其进行评估是两个独立的步骤。使用时,图表不会发生变化,因此可以进行优化。 torch 急切地评估张量上的操作。这使 API 更方便(无会话),但也失去了识别和优化总是按顺序发生的操作的潜力。

现在这种差异变得不那么明显了。 Tensorflow 通过tf eager 回应了火炬的流行。还有 JAX 项目,它建立在与 tensorflow (XLA) 相同的底层框架上。 JAX 没有会话的概念。但它允许您通过简单地调用jit 将多个操作一起编译。

自从 Tensorflow 开始涵盖 PyTorch 功能后,PyTorch 是否也在努力整合 Tensorflow 的优势? PyTorch(或其路线图)中是否有类似会话或 jit 功能的东西?

API 文档有一个 jit section,但据我所知,这更多是关于导出模型。

【问题讨论】:

【参考方案1】:

正如您所提到的,有一个torch.jit,它的目的也是在导出的图中引入优化(例如内核融合、常量优化等)。 IIRC 你可以在他们的 github repo here 中找到一些源代码,但我不确定这些源代码是否在文档中的某处明确提及(或明确足以被记住)。

由于1.3 还引入了量化(有关一些介绍,请参见here)。在教程部分,即here,您可以看到Conv2dBatchNormReLU 的显式融合以提高性能。 Ofc 还存在特定的东西,例如使用 int 而不是 float 进行权重(量化)、混合算术(尽可能使用 half 浮点精度,请参阅 NVidia 的 Apex)等。

最后但同样重要的是,我不认为对于使用矢量化操作并使用 torchscript 导出的编写良好的模型,您会看到真正显着的运行时差异因为一些通用图优化。无论您要使用 GPU、CPU、TPU,它们的版本是什么,您是仅进行推理还是训练等等,仍然有所不同。很难确定 tensorflowpytorch 相比有多快(除了两个框架中的一些众所周知的问题)。总而言之,这取决于AFAIK,并且测量结果差异很大。

顺便说一句。当谈到每个框架的优势时,它们的核心确实开始涵盖类似的东西(PyTorch 最近获得了移动支持,请参阅here)。真正的区别仍然是不同的底层方法以及每个框架必须做什么来规避这些限制。

【讨论】:

以上是关于pytorch 可以优化顺序操作(如张量流图或 JAX 的 jit)吗?的主要内容,如果未能解决你的问题,请参考以下文章

如何找出冻结的张量流图的正确输入和输出操作?

Pytorch基础-张量基本操作

Tensorflow瞎搞

colab pytorch张量操作

Pytorch中的tensor常用操作

在 Android 上使用来自冻结的张量流图的变量