从零开始学深度学习编译器八,TVM的算符融合以及如何使用TVM Pass Infra自定义Pass

Posted just_sort

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了从零开始学深度学习编译器八,TVM的算符融合以及如何使用TVM Pass Infra自定义Pass相关的知识,希望对你有一定的参考价值。

0x00. 前言

上篇文章系统的介绍了TVM Pass Infra,并解析了Relay树结构以及Constant Folding Pass的具体做法。这一节,我想先补充一下TVM Pass Infra的用法,接着介绍一下TVM算符融合需要用到的支配树这个数据结构,最后再深入分析TVM中的一个非常重要的Pass即算符融合。

0x01. 如何使用TVM Pass Infra

关于TVM Pass Infra的介绍可以移步【从零开始学深度学习编译器】七,万字长文入门TVM Pass查看。这里来介绍一下TVM Pass Infra的使用方法,内容翻译自https://tvm.apache.org/docs/tutorials/dev/use_pass_infra.html,加了一些自己的理解。

随着 Relay/tir 中优化pass次数的增加,手动执行它们并维护它们的依赖关系变得棘手。 因此,我们引入了一个Pass基础设施来管理优化passes,并使其适用于 TVM 栈中不同层的 IR。

Relay/tir 程序的优化Pass可以应用于各种粒度,即分别使用 tvm.relay.transform.FunctionPass/tvm.tir.transform.PrimFuncPasstvm.transform.ModulePass 的function-level和module-level级别的优化pass。 或者用户可以依靠 tvm.transform.Sequential 在 Relay/tir 程序上应用一系列passes,其中passes之间的依赖关系可以通过Pass Infra解决。

这里主要是来演示一些开发人员如何使用Pass Infra来进行某种优化,并为Relay程序创建优化管道。这里的方法同样适用于tir。首先导入一些必要的包。

import numpy as np
import tvm
from tvm import te
import tvm.relay as relay

接下来,展示了一个简单的Relay程序,该程序将用于执行各种实例Pass的例子。同样,用户也可以编写一个tir原始函数并应用Pass。

创建一个Relay 程序示例

def example():
    shape = (1, 64, 54, 54)
    c_data = np.empty(shape).astype("float32")
    c = relay.const(c_data)
    weight = relay.var("weight", shape=(64, 64, 3, 3))
    x = relay.var("x", relay.TensorType((1, 64, 56, 56), "float32"))
    conv = relay.nn.conv2d(x, weight)
    y = relay.add(c, c)
    y = relay.multiply(y, relay.const(2, "float32"))
    y = relay.add(conv, y)
    z = relay.add(y, c)
    z1 = relay.add(y, c)
    z2 = relay.add(z, z1)
    return relay.Function([x, weight], z2)

然后这里给一个conv op注册一个输出数据排布更改的Pass,这个Pass将卷积层的NCHW数据排布变化成NCHW16c的数据排布。

@relay.op.register_alter_op_layout("nn.conv2d", level=101)
def alter_conv2d(attrs, inputs, tinfos, out_type):
    data, weight = inputs
    new_attrs = dict(attrs)
    new_attrs["data_layout"] = "NCHW16c"
    return relay.nn.conv2d(data, weight, **new_attrs)

优化程序

在应用Pass之前我们看一下Relay程序长什么样:

def @main(%x: Tensor[(1, 64, 56, 56), float32], %weight: Tensor[(64, 64, 3, 3), float32]) {
  %0 = add(meta[relay.Constant][0], meta[relay.Constant][0]);
  %1 = nn.conv2d(%x, %weight, padding=[0, 0, 0, 0]);
  %2 = multiply(%0, 2f);
  %3 = add(%1, %2);
  %4 = add(%3, meta[relay.Constant][0]);
  %5 = add(%3, meta[relay.Constant][0]);
  add(%4, %5)
}

现在我们要优化程序。 Relay 具有许多优化功能。 我们将选择其中的一些应用到这个示例程序中。

手动应用优化Passes,这里使用一个FoldConstant的Pass。

# Let's first create a relay Module which contains one or multiple Relay
# functions for optimization.
f = example()
mod = tvm.IRModule.from_expr(f)

# Now we can apply constant folding on the module.
# fold_const here is a callback that doesn't take any parameters.
fold_const = relay.transform.FoldConstant()
# Then, we can invoke the pass on the given module. Note that the constant
# folding pass works at the function-level. That being said, each function in
# the module will be applied with the optimization. Users don't need to iterate
# through individual functions manually to apply this pass.
mod = fold_const(mod)
# We can see from the updated program that the constants are folded.
print(mod)

应用了FoldConstant Pass之后Relay程序长这样:

def @main(%x: Tensor[(1, 64, 56, 56), float32], %weight: Tensor[(64, 64, 3, 3), float32]) -> Tensor[(1, 64, 54, 54), float32] {
  %0 = nn.conv2d(%x, %weight, padding=[0, 0, 0, 0]) /* ty=Tensor[(1, 64, 54, 54), float32] */;
  %1 = add(%0, meta[relay.Constant][0] /* ty=Tensor[(1, 64, 54, 54), float32] */) /* ty=Tensor[(1, 64, 54, 54), float32] */;
  %2 = add(%1, meta[relay.Constant][1] /* ty=Tensor[(1, 64, 54, 54), float32] */) /* ty=Tensor[(1, 64, 54, 54), float32] */;
  %3 = add(%1, meta[relay.Constant][1] /* ty=Tensor[(1, 64, 54, 54), float32] */ /* ty=Tensor[(1, 64, 54, 54), float32] */) /* ty=Tensor[(1, 64, 54, 54), float32] */;
  add(%2, %3) /* ty=Tensor[(1, 64, 54, 54), float32] */
}

可以看到相对于优化之前的IR,应用了FoldConstant Pass之后初始IR的%2 = multiply(%0, 2f);由于是一个常量直接被折叠起来变成了relay.Constant][1]。接下来可以以类似的方式应用更多优化。 例如,我们可以消除 z 和 z1 使用的公共表达式,即使用EliminateCommonSubexpr Pass。

mod = relay.transform.EliminateCommonSubexpr()(mod)
print(mod)

看下面的图就很清晰了。

公共表达式消除Pass

一些优化,例如fuse,也是带一些配置参数的。 例如,opt_level 0 将不允许运算融合在一起。 用户可以通过fuse_opt_level来启用它。

mod = relay.transform.FuseOps(fuse_opt_level=0)(mod)

# We can observe that the optimized module contains functions that only have
# a signle primitive op.
print(mod)

这样IR就会是下面展示的样子:

def @main(%x: Tensor[(1, 64, 56, 56), float32], %weight: Tensor[(64, 64, 3, 3), float32]) -> Tensor[(1, 64, 54, 54), float32] {
  %0 = fn (%p03: Tensor[(1, 64, 56, 56), float32], %p12: Tensor[(64, 64, 3, 3), float32], Primitive=1) -> Tensor[(1, 64, 54, 54), float32] {
    nn.conv2d(%p03, %p12, padding=[0, 0, 0, 0]) /* ty=Tensor[(1, 64, 54, 54), float32] */
  };
  %1 = %0(%x, %weight) /* ty=Tensor[(1, 64, 54, 54), float32] */;
  %2 = fn (%p02: Tensor[(1, 64, 54, 54), float32], %p11: Tensor[(1, 64, 54, 54), float32], Primitive=1) -> Tensor[(1, 64, 54, 54), float32] {
    add(%p02, %p11) /* ty=Tensor[(1, 64, 54, 54), float32] */
  };
  %3 = %2(%1, meta[relay.Constant][0] /* ty=Tensor[(1, 64, 54, 54), float32] */) /* ty=Tensor[(1, 64, 54, 54), float32] */;
  %4 = fn (%p01: Tensor[(1, 64, 54, 54), float32], %p1: Tensor[(1, 64, 54, 54), float32], Primitive=1) -> Tensor[(1, 64, 54, 54), float32] {
    add(%p01, %p1) /* ty=Tensor[(1, 64, 54, 54), float32] */
  };
  %5 = %4(%3, meta[relay.Constant][1] /* ty=Tensor[(1, 64, 54, 54), float32] */) /* ty=Tensor[(1, 64, 54, 54), float32] */;
  %6 = fn (%p0: Tensor[(1, 64, 54, 54), float32], Primitive=1) -> Tensor[(1, 64, 54, 54), float32] {
    add(%p0, %p0) /* ty=Tensor[(1, 64, 54, 54), float32] */
  };
  %6(%5) /* ty=Tensor[(1, 64, 54, 54), float32] */
}

使用 Sequential 应用一系列Pass

像上面这样应用pass实际上很麻烦,它可能需要用户更好地理解它们之间的依赖关系。 例如,目前 fusion 在 let bindings上效果不佳。 因此,如果在融合之前应用 relay.transform.ToANormalForm() ,我们将无法融合可融合的运算符,因为此Pass为每个表达式生成 let bindings以规范 Relay 程序。

因此,Relay 提供了 tvm.transform.Sequential,通过指定每个Pass所需的passes并将它们打包为一个整体来执行,从而减轻开发人员明确处理这些问题的负担。 例如,现在可以使用sequential 样式应用相同的passes,如下所示。 tvm.transform.Sequential 类似于 torch.nn.sequentialmxnet.gluon.block。 例如,torch.nn.sequential 用于包含将被添加以构建网络的一系列 PyTorch Module,它侧重于网络层。 相反,我们的Pass Infra中的 tvm.transform.Sequential 用于优化Pass。

# Now let's execute some passes through :py:class:`tvm.transform.Sequential`
f = example()
mod = tvm.IRModule.from_expr(f)
# Glob the interested passes.
seq = tvm.transform.Sequential(
    [
        relay.transform.FoldConstant(),
        relay.transform.EliminateCommonSubexpr(),
        relay.transform.FuseOps(fuse_opt_level=2),
    ]
)
mod1 = seq(mod)
print(mod1)

输出:

def @main(%x: Tensor[(1, 64, 56, 56), float32], %weight: Tensor[(64, 64, 3, 3), float32]) -> Tensor[(1, 64, 54, 54), float32] {
  %4 = fn (%p0: Tensor[(1, 64, 56, 56), float32], %p1: Tensor[(64, 64, 3, 3), float32], %p2: Tensor[(1, 64, 54, 54), float32], %p3: Tensor[(1, 64, 54, 54), float32], Primitive=1) -> Tensor[(1, 64, 54, 54), float32] {
    %0 = nn.conv2d(%p0, %p1, padding=[0, 0, 0, 0]) /* ty=Tensor[(1, 64, 54, 54), float32] */;
    %1 = add(%0, %p2) /* ty=Tensor[(1, 64, 54, 54), float32] */;
    %2 = add(%1, %p3) /* ty=Tensor[(1, 64, 54, 54), float32] */;
    %3 = add(%1, %p3) /* ty=Tensor[(1, 64, 54, 54), float32] */;
    add(%2, %3) /* ty=Tensor[(1, 64, 54, 54), float32] */
  };
  %4(%x, %weight, meta[relay.Constant][0] /* ty=Tensor[(1, 64, 54<

以上是关于从零开始学深度学习编译器八,TVM的算符融合以及如何使用TVM Pass Infra自定义Pass的主要内容,如果未能解决你的问题,请参考以下文章

从零开始学深度学习编译器五,TVM Relay以及Pass简介

从零开始学深度学习编译器六,TVM的编译流程详解

从零开始学深度学习编译器七,万字长文入门TVM Pass

从零开始学深度学习编译器番外二,在Jetson Nano上玩TVM

从零开始学深度学习编译器九,TVM的CodeGen流程

从零开始学深度学习编译器十八,MLIR中的Interfaces