从零开始学深度学习编译器七,万字长文入门TVM Pass

Posted just_sort

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了从零开始学深度学习编译器七,万字长文入门TVM Pass相关的知识,希望对你有一定的参考价值。

0x0. 前言

这篇文章基于TVM 0.8.0.dev版本。在【从零开始学深度学习编译器】五,TVM Relay以及Pass简介 这篇推文中已经简单介绍了Relay和Pass机制。但对Pass的基础设施(Pass Infrastructure)和Relay树结构都没有详细介绍,所以这篇文章主要介绍一下Pass Infrastructure和Relay树结构,再基于这些关键的基础知识详细了解一下Constant Folding Pass,相信读者读完这篇文章会对TVM的Pass有更深的理解,并且在阅读其它Pass和实现自定义Pass时可以很Relax。

0x1. Pass Infrastructure

首先来看Pass Infrastructure,基于官方文档进行介绍。

在讲解Pass通用的注册和运行流程前,先来介绍一下TVM的Pass Infrastructure。参考官方文档:https://tvm.apache.org/docs/dev/pass_infra.html

Relay 和 TVM IR 都包含一系列优化passes,可提高模型的性能指标,例如平均推理速度、内存占用或特定设备的功耗。 TVM有一套标准优化方法以及特定于机器学习的优化方法,包括常量折叠、死代码消除、运算符布局更改、算符融合、缓冲区处理和循环变换等。 每一个Pass都使用在traversal期间和/或之前收集的分析结果来构造ir-to-ir的pass。

然而,随着TVM的迅速发展,需要一种更系统、更有效的方法来管理这些passes。此外,一个可以管理跨TVM堆栈不同层(如Relay和tir)的passes的通用框架,为开发人员快速原型化并将实现的passes插入系统铺平了道路。

例如,许多现有的生产编译器,如 GCC 和 LLVM,都采用pass manager来有效管理passes的执行。 最初管理 pass 很简单,因为 pass 的数量很少,但成熟的编译器将包含数百个单独的 pass。 Often external users will want to have custom passes correctly scheduled without having to modify a single handcrafted pass order.

同样,现代深度学习框架,如 Pytorch 和 MXNet Gluon,也有分别通过 Sequential 和 Block 启用pass-style层构建方案的趋势。 有了这样的结构,这些现代框架能够方便地将模块/层添加到它们的容器中,并轻松地构建神经网络。

Relay pass infra 的设计很大程度上受到 LLVM 中使用的分层pass manager和流行的深度学习框架中使用的block-style容器的启发。 pass infra 的主要目标包括:

  • 实现更好的optimizer编程编排。 这允许用户灵活地定制和构建自己的优化管道。
  • 提供一种用户友好的方式来调试passes。
  • 减轻开发人员手动和分别解决passes之间的依赖关系。
  • 为开发人员简化实现新passes的难度。 例如,我们允许用户在 Python 中实现一个 pass 并让 pass infra 操纵它的执行。

The Design

我们专注于为用户提供易于扩展的功能,让用户可以快速添加新passes而不会失去向后兼容性。 该设计包含后端和前端。 前者实现了 pass infra 的主要逻辑。 后者为用户提供简单的 API 进行交互,即允许用户快速创建自己的优化管道。

C++ Backend

我们提供了一个 PassInfo 对象来包含一个pass所需的基本信息。 name 是 pass 名称,opt_level 指示将启用 pass 的优化级别, required 表示执行某个 pass 所需的 pass(更多详细信息请参见include/tvm/ir/transform.h)。 例如,在注册pass的时候(将在后面介绍),pass开发人员可以指定pass的名称、将执行的优化级别和/或所需的pass。 opt_level 可用于帮助 pass infra 识别在用户提供的优化级别下运行时是否需要执行某个 pass。 required字段可以由pass infra用来解决pass依赖关系。

class PassInfoNode : public Object {
  String name;
  int opt_level;
  Array<String> required;
};

PassContext

PassContext 带有用于优化pass的有用信息。 例如,它包含错误报告系统,因此pass的作者可以提供有关优化失败原因的注释。 PassContext 还旨在替换旧的BuildConfig,它用于帮助用户配置编译选项,包括优化级别和必需/禁用的pass等。例如,我们可能有一个配置,它在 opt_level=3 时执行所有pass,除开使用 PassContext 提供的 disabled_pass=xx禁用的一些passes 。 现在我们可以在 opt_level=3 处对所有passes进行全局处理,并排除禁用pass列表中的那些pass。

这个类是为方便用户编写Python而设计的,它的语法可以在特定的配置下执行优化。 此外,用户可以通过 PassContext::Current()以线程安全的方式获取某个程序范围内可用的context,因为ThreadLocalStore用于保存创建的pass context对象,关于ThreadLocalStore建议看这篇文章:https://zhuanlan.zhihu.com/p/61587053,TVM模仿Java中的ThreadLocalStore在C++层自己实现了用来管理线程。 稍后将提供示例以展示我们如何使用 C++ 和 Python API 来创建使用pass context的编译管道。

class PassContextNode : public Object {
 public:
  ErrorReporter err_reporter;
  int opt_level{2};
  tvm::Array<tvm::Expr> required_pass;
  tvm::Array<tvm::Expr> disabled_pass;
};

class PassContext : public NodeRef {
 public:
  TVM_DLL static PassContext Create();
  TVM_DLL static PassContext Current();
  /* Other fields are omitted. */

 private:
  // The entry of a pass context scope.
  TVM_DLL void EnterWithScope();
  // The exit of a pass context scope.
  TVM_DLL void ExitWithScope();

  // Classes to get the Python `with` like syntax.
  friend class tvm::With<PassContext>;
};

struct PassContextThreadLocalEntry {
  /*! \\brief The default pass context. */
  PassContext default_context;
  /*! \\brief The current pass context. */
  std::stack<PassContext> context_stack;
  PassContextThreadLocalEntry() {
    default_context = PassContext(make_node<PassContextNode>());
  }
};

/*! \\brief The thread-local store to hold the pass context. */
typedef dmlc::ThreadLocalStore<PassContextThreadLocalEntry>
     PassContextThreadLocalStore;

Pass Constructs

pass infra 是以分层方式设计的,它可以在不同粒度的Relay/tir 程序下工作。 引入了一个纯虚拟类 PassNode 作为不同优化pass的基础。 此类包含几个必须由子类在modules, functions, or sequences of passes实现的虚拟方法。

class PassNode : Object {
  virtual PassInfo Info() const = 0;
  virtual Module operator()(const IRModule& mod
                            const PassContext& pass_ctx) const = 0;
};

成员函数展示了一个pass应该如何实现,例如它始终在特定context下工作在 IRModule中,所有的pass都被设计在一个Module to Module的管理器中。因此,由 pass infra 控制的优化将始终更新整个module。

已经创建了几个子类来实现不同类型的优化pass,例如,function-level passes, module-level passes, and sequential passes。 每个子类本身都可以充当pass管理器。 例如,他们可以收集所需的passes并执行它们或基于给定的元数据构建依赖关系图。 它们的完整定义可以在src/relay/ir/transform.cc 和 src/ir/transform.cc 中找到。

Module-Level Passes

Module Level Passes主要用于全局和过程间优化 (IPO),类似于 LLVM 中使用的module pass。 Relay 中一些典型的 pass 需要一个模块的global picture,比如 A-normal form conversion 和 lambda lifting等,都属于这个集合。 在此级别,用户甚至可以在一个module中添加和/或删除function。

class ModulePassNode : PassNode {
  PassInfo pass_info;
  runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func;
  Module operator()(const Module& mod, const PassContext& pass_ctx) const final;
  // Other members/methods are omitted
};

pass_info 维护module-level pass所需的信息。 pass_func 实现了真正的optimization。 例如,我们可能需要对module执行死代码消除。 我们可以在 pass_func 中实现算法并让它在module上运行。 然后它将删除死代码,包括module中未使用的函数。 请注意,该字段被设计为一个packed function,所以这个优化不仅可以使用C++还可以使用Python来实现。

Function-Level Passes

Function-level passes用于为给定的 Relay/tir module实现各种内部函数级优化。 它一次从module的函数列表中获取一个函数以进行优化,并生成一个重写的 Relay Functiontir PrimFunc。 大多数pass可以归入这一类,例如Relay中的常见子表达式消除和inference simplification 以及tir中的向量化和flattening storage等。

请注意,此级别的passes范围是 Relay Function或 tir PrimFunc。 因此,我们无法通过这些passes添加或删除函数,因为它们不知道全局信息。

class FunctionPassNode : PassNode {
  PassInfo pass_info;
  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func;
  Module operator()(const Module& mod, const PassContext& pass_ctx) const final;
  bool SkipFunction(const Function& func) const;
  // Other members/methods are omitted...
};

pass_info 与我们刚刚在Module pass 中描述的相同。 pass_func 需要一个函数进行优化,它还需要一个Module,因为我们可能会使用它来报告错误。 一个函数可以用“SkipOptimization”注释,以便在优化过程中被忽略。

Sequential Passes

SequentialPass 类似于 Pytorch nn.Sequential,它包含许多用于执行的passes。

class SequentialPassNode : PassNode {
  PassInfo pass_info;
  // Passes need to be executed.
  Array<Pass> passes;
  bool PassEnabled(const PassInfo& info) const;
  Module operator()(const Module& mod, const PassContext& pass_ctx) const final;
};

目前在Relay中只有少数passes 被放入这组中。 例如,FoldScaleAxis 需要在内部调度 ForwardFoldScaleAxisBackwardFoldScaleAxis。 此外,建议先完成BackwardFoldScaleAxis。 因此,该pass是SequentialPass的理想候选者。

以下代码显示了如何调用sequential pass中的各个pass。

Module SequentialNode::operator()(const Module& module,
                                  const PassContext& pass_ctx) const {
  Module mod = module;
  for (const Pass& pass : passes) {
    ICHECK(pass.defined()) << "Found undefined pass for optimization.";
    const PassInfo& pass_info = pass->Info();
    if (!PassEnabled(pass_info))  continue;
    for (const auto& it : pass_info->required) {
      const auto* name = it.as<tvm::ir::StringImm>();
      ICHECK(name);
      mod = GetPass(name->value)(mod, pass_ctx);
    }
    mod = pass(mod, pass_ctx);
  }
  return mod;
}

在调用pass时,我们首先检查是否启用了此pass。 这是通过首先检查用户是否明确禁用该pass,然后检查它是否被用户指定为必需pass来完成的。 如果仍然不确定是否启用了此传递,则将检查其 opt_level。 只有当它的opt_level不低于pass context中配置的优化级别时,才会启用并因此执行此pass。

要执行pass,我们首先需要使用pass name在 TVM packed function注册表中已注册的pass。 这是可能的,因为每个pass都注册了一个 API 接口,我们将在后面展示。

Pass GetPass(const std::string& pass_name) {
  using tvm::runtime::Registry;
  std::string fpass_name = "relay._transform." + pass_name;
  const auto* f = Registry::Get(fpass_name);
  ICHECK(f != nullptr) << "Cannot find " << fpass_name
                      << "to create the pass " << pass_name;
  return (*f)();
}

提供了一些helper function来创建上述每种类型的Pass。 这些helper function也暴露给 Python 前端,以便用户可以方便地使用 Python API 来创建特定的 pass 对象。

Pass CreateFunctionPass(
    const runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)>& pass_func,
    int opt_level,
    String name,
    Array<String> required);

Pass CreatePrimFuncPass(
    const runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)>& pass_func,
    int opt_level,
    String name,
    Array<String> required);

Pass CreateModulePass(
    const runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)>& pass_func,
    int opt_level,
    String name,
    Array<String> required);

Pass Sequential(tvm::Array<Pass> passes, PassInfo pass_info);

Pass Registration

我们已经介绍了不同级别pass的概念和用于编译的context。 用户可以多么轻松地注册pass是一件有意义的事。,我们以constant folding为例。 这个 pass 已经被实现来折叠 Relay Function中的常量(在 tvm/src/relay/transforms/fold_constant.cc 中找到)。

提供了一个 API 来执行 ExprExpr 的转换。

Expr FoldConstant(const Expr& expr);

为了将这个pass注册到pass infra,我们首先需要决定这个pass将在哪个级别执行。 由于常量折叠发生在单个函数上,我们应该直观地通过 CreateFunctionPass为其创建一个 FunctionPasspass_func 作为packed function返回,该函数在 IRModule 中的每个function上调用 Expr to Expr API。 {} 表示此pass不需要先决条件。 否则,pass开发人员必须识别并列出它们。

namespace transform {

Pass FoldConstant() {
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
    [=](Function f, IRModule m, PassContext pc) {
      return Downcast<Function>(FoldConstant(f));
  };
  return CreateFunctionPass(pass_func, 2, "FoldConstant", {});
}

TVM_REGISTER_GLOBAL("relay._transform.FoldConstant")
.set_body_typed(FoldConstant);

}  // namespace transform

为了允许其他 C++ 模块应用此pass,我们在 include/tvm/relay/transform.h中声明了一个free function,如下所示:

TVM_DLL Pass FoldConstant();

Python Frontend

python前端只需要一些简单的 APIs。 例如,我们可以为用户提供以下 APIs 来创建和执行一个 pass(完整的实现在 python/tvm/relay/transform.pypython/tvm/ir/transform.py 中提供)。 后端接收信息并决定它应该使用哪个函数来创建 Pass 对象。

PassContext

Python 前端为 PassContext 提供了一个包装器,通过覆盖 __enter____exit__ 来启用 with 语法。 为用户提供了一个 current 静态方法来获取在特定范围内使用的上下文。

@tvm._ffi.register_object("transform.PassContext")
class PassContext(tvm.runtime.Object):
    def __enter__(self):
        _transform.EnterPassContext(self)
        return self

    def __exit__(self, ptype, value, trace, config):
        _transform.ExitPassContext(self)

    @staticmethod
    def current():
        """Return the current pass context."""
        return _transform.GetCurrentPassContext()

PassContext 用于配置编译选项,包括优化级别和必需/禁用的pass。 它还可以带一个配置字典,以便不同的pass可以方便地获取passed的数据,例如回退设备信息和循环展开的步数/深度等。 为了能够获取所需的配置,必须通过TVM_REGISTER_PASS_CONFIG_OPTION注册关键字。 例如,loop unrolling pass使用以下内容:

TVM_REGISTER_PASS_CONFIG_OPTION("tir.UnrollLoop", UnrollLoopConfig);

更多细节请参考 src/tir/transforms/unroll_loop.cc

Pass Objects

Pass 是所有 pass 对象的基类。 这里的所有方法都只是在后端实现的简单包装器。 它们是为了用户方便地与 Python 中的基类进行交互而定义的。 在 pass 基类中只定义了一个__call__来使子类成为可调用对象,以便它们可以很容易地被调用(例如 pass_xx(arg))来执行。

@register_relay_node
class Pass(RelayNode):
   def __call__(self, mod):
       return _transform.RunPass(self, mod)

提供了一些辅助 APIs 以支持从 Python 前端轻松创建pass并让pass infra控制执行。 比如提供给用户module_passfunction_passsequential,让他们可以自定义自己的pass或者pass管道。

对于在C++后端实现的所有pass,我们分别在python/tvm/ir/transform.pypython/tvm/relay/transform.py中提供了相应的Python API。 例如,const 折叠有一个 Python API,如下所示:

def FoldConstant():
    return _transform.FoldConstant()

用户可以通过装饰器像下面这样构建一个pass:

 @relay.transform.module_pass(opt_level=2)
 def transform(mod, ctx):
    tp = relay.TensorType((10,), "float32")
    x = relay.var("x", tp)
    gv = relay.GlobalVar("abs")
    func = relay.Function([x], relay.abs(x))
    new_mod = relay.Module({gv: func})
    new_mod.update(mod)
    return new_mod

module_pass = transform
assert isinstance(module_pass, transform.ModulePass)
assert module_pass.info.opt_level == 2

这里的transform函数向输入的module添加了一个abs 函数,但它可以是module level的任何自定义pass。 创建此 module_pass 后,用户可以将其应用于任何 Relay 模块。 例如,我们可以构建一个empty module并应用此pass来添加 abs 函数。

mod = relay.Module()
mod = module_pass(mod)

相应地,我们也为 function_pass 提供了这样的功能。 例如,一个示例function-level pass可以写成如下:

@relay.transform.function_pass(opt_level=1)
class TestReplaceFunc:
   def __init__(self, new_func):
      self.new_func = new_func
      def transform_function(self, func, mod, ctx):
         # Just for demo purposes
         # Transform func to new_func
         return self.new_func

x = relay.var("x", shape=(10, 20))
f1 = relay.Function([x], x)
f2 = relay.Function([x], relay.log(x))
# fpass is now a special pass that replaces every
# function to f1
fpass = TestReplaceFunc(f1)
# Now every function in input_mod is replaced by f1
res_mod = fpass(input_mod)

或者,用户也可以不使用装饰器直接注册pass,然后调用它。 有关如何自定义您自己的优化管道以及调试 Relay 和 tir pass 的更多示例,请参阅 use pass infra 教程(https://github.com/apache/tvm/blob/main/tutorials/dev/use_pass_infra.py)。

0x2. TVM Relay树结构

AST

摘自wiki
在计算机科学中,抽象语法树(Abstract Syntax Tree,AST),或简称语法树(Syntax tree),是源代码语法结构的一种抽象表示。它以树状的形式表现编程语言的语法结构,树上的每个节点都表示源代码中的一种结构。之所以说语法是“抽象”的,是因为这里的语法并不会表示出真实语法中出现的每个细节。比如,嵌套括号被隐含在树的结构中,并没有以节点的形式呈现;而类似于 if-condition-then 这样的条件跳转语句,可以使用带有三个分支的节点来表示。
和抽象语法树相对的是具体语法树(通常称作分析树)。一般的,在源代码的翻译和编译过程中,语法分析器创建出分析树,然后从分析树生成AST。一旦AST被创建出来,在后续的处理过程中,比如语义分析阶段,会添加一些信息。

之前在解析TVM Relay的ONNX前端的时候,已经提到在完成每个OP转换之后需要使用IRModule.from_expr将所有转换后的Relay Function包起来返回,过程如下,这里关心最后一行代码即可:

def from_onnx(self, graph, opset, get_output_expr=False):
        """基于ONNX模型构建Relay IR。

        参数
        ----------
        graph : onnx protobuf 对象
           加载进来的ONNX Graph

        opset : 操作集版本

        get_output_expr: bool
            如果设置为true,则此转换将返回每个输出表达式,而不是打包的模块。 
            将子图转换为Relay时,这可能很有用。 

        Returns
        -------
        mod : tvm.IRModule
            The returned relay module

        params : dict
            A dict of name: tvm.nd.array pairs, used as pretrained weights
        """
        self.opset = opset
        # 解析网络的输入到relay中, 又叫参数,onnx的initializer就是用来保存模型参数的
        for init_tensor in graph.initializer:
            if not init_tensor.name.strip():
                raise ValueError("Tensor's name is required.")
            # 具体实现就是先把这个TensorProto使用get_numpy函数获得值,再reshape到特定形状,再基于这个numpy构造tvm.nd.array。
            array = self._parse_array(init_tensor)
            # 前面解释过,如果设置冻结参数,则将这个参数设置为Relay中的常量OP
            if self._freeze_params:
                
                self._nodes[init_tensor.name] = _expr.const(array)
            else:
                self._params[init_tensor.name] = array
                self._nodes[init_tensor.name] = new_var(
                    init_tensor.name,
                    shape=self._params[init_tensor.name].shape,
                    dtype=self._params[init_tensor.name].dtype,
                )
        # 解析ONNX模型的输入
        for i in graph.input:
            # from onnx v0.2, Graphproto.input has type ValueInfoProto,
            #  and the name is 'i.name'
            # 获取i这个输入的名字,shape,数据类型以及shape每个维度对应的名字
            i_name, i_shape, d_type, i_shape_name = get_info(i)
            # 判断i这个输入是权重参数还是输入
            if i_name in self._params:
                # i is a param instead of input
                self._num_param += 1
                self._params[i_name] = self._params.pop(i_name)
                self._nodes[i_name] = new_var(
                    i_name, shape=self._params[i_name].shape, dtype=self._params[i_name].dtype
                )
            # 输入节点已经在Relay IR中了就不用处理了
            elif i_name in self._nodes:
                continue
            else:
                # 真正的输入节点,依赖用户进行指定
                self._num_input += 1
                self._input_names.append(i_name)
                if i_name in self._shape:
                    i_shape = self._shape[i_name]
                else:
                    if "?" in str(i_shape):
                        warning_msg = (
                            "Input %s has unknown dimension shapes: %s. "
                            "Specifying static values may improve performance"
                            % (i_name, str(i_shape_name))
                        )
                        warnings.warn(warning_msg)
                if isinstance(self._dtype, dict):
                    dtype = self._dtype[i_name] if i_name in self._dtype else d_type
                else:
                    dtype = d_type
                self._nodes[i_name] = new_var(i_name, shape=i_shape, dtype=dtype)
            self._inputs[i_name] = self._nodes[i_name]
        # Only check user inputs in the outer-most graph scope.
        if self._old_manager is None:
            assert all(
                [name in self._input_names for name in self._shape.keys()]
            ), "User specified the shape for inputs that weren't found in the graph: " + str(
                self._shape
            )
        # 获取不支持的算子列表
        convert_map = _get_convert_map(opset)
        unsupported_ops = set()
        for node in graph.node

以上是关于从零开始学深度学习编译器七,万字长文入门TVM Pass的主要内容,如果未能解决你的问题,请参考以下文章

从零开始学深度学习编译器九,TVM的CodeGen流程

从零开始学深度学习编译器六,TVM的编译流程详解

从零开始学深度学习编译器五,TVM Relay以及Pass简介

从零开始学深度学习编译器八,TVM的算符融合以及如何使用TVM Pass Infra自定义Pass

从零开始学深度学习编译器番外二,在Jetson Nano上玩TVM

从零开始学自然语言处理-十万字长文带你深入学习自然语言处理全流程