SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起

Posted 鸿乃江边鸟

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起相关的知识,希望对你有一定的参考价值。

背景

本文基于 SPARK 3.3.0
从一个unit test来探究SPARK Codegen的逻辑,

  test("SortAggregate should be included in WholeStageCodegen") 
    val df = spark.range(10).agg(max(col("id")), avg(col("id")))
    withSQLConf("spark.sql.test.forceApplySortAggregate" -> "true") 
      val plan = df.queryExecution.executedPlan
      assert(plan.exists(p =>
        p.isInstanceOf[WholeStageCodegenExec] &&
          p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SortAggregateExec]))
      assert(df.collect() === Array(Row(9, 4.5)))
    
  

该sql形成的执行计划第一部分的全代码生成部分如下:

   WholeStageCodegen     
   +- *(1) SortAggregate(key=[], functions=[partial_max(id#0L), partial_avg(id#0L)], output=[max#12L, sum#13, count#14L])
      +- *(1) Range (0, 10, step=1, splits=2)    

分析

第一阶段wholeStageCodegen

第一阶段的代码生成涉及到SortAggregateExec和RangeExec的produce和consume方法,这里一一来分析:
第一阶段wholeStageCodegen数据流如下:

WholeStageCodegenExec      SortAggregateExec(partial)     RangeExec        
  =========================================================================
 
  -> execute()
      |
   doExecute() --------->   inputRDDs() -----------------> inputRDDs() 
      |
   doCodeGen()
      |
      +----------------->   produce()
                              |
                           doProduce() 
                              |
                           doProduceWithoutKeys() -------> produce()
                                                              |
                                                          doProduce()
                                                              |
                           doConsume()<------------------- consume()
                              |
                           doConsumeWithoutKeys()
                              |并不是doConsumeWithoutKeys调用consume,而是由doProduceWithoutKeys调用
   doConsume()  <--------  consume()

RangeExec的consume方法

final def consume(ctx: CodegenContext, outputVars: Seq[ExprCode], row: String = null): String = 
    val inputVarsCandidate =
      if (outputVars != null) 
        assert(outputVars.length == output.length)
        // outputVars will be used to generate the code for UnsafeRow, so we should copy them
        outputVars.map(_.copy())
       else 
        assert(row != null, "outputVars and row cannot both be null.")
        ctx.currentVars = null
        ctx.INPUT_ROW = row
        output.zipWithIndex.map  case (attr, i) =>
          BoundReference(i, attr.dataType, attr.nullable).genCode(ctx)
        
      

    val inputVars = inputVarsCandidate match 
      case stream: Stream[ExprCode] => stream.force
      case other => other
    

    val rowVar = prepareRowVar(ctx, row, outputVars)

    // Set up the `currentVars` in the codegen context, as we generate the code of `inputVars`
    // before calling `parent.doConsume`. We can't set up `INPUT_ROW`, because parent needs to
    // generate code of `rowVar` manually.
    ctx.currentVars = inputVars
    ctx.INPUT_ROW = null
    ctx.freshNamePrefix = parent.variablePrefix
    val evaluated = evaluateRequiredVariables(output, inputVars, parent.usedInputs)

    // Under certain conditions, we can put the logic to consume the rows of this operator into
    // another function. So we can prevent a generated function too long to be optimized by JIT.
    // The conditions:
    // 1. The config "spark.sql.codegen.splitConsumeFuncByOperator" is enabled.
    // 2. `inputVars` are all materialized. That is guaranteed to be true if the parent plan uses
    //    all variables in output (see `requireAllOutput`).
    // 3. The number of output variables must less than maximum number of parameters in Java method
    //    declaration.
    val confEnabled = conf.wholeStageSplitConsumeFuncByOperator
    val requireAllOutput = output.forall(parent.usedInputs.contains(_))
    val paramLength = CodeGenerator.calculateParamLength(output) + (if (row != null) 1 else 0)
    val consumeFunc = if (confEnabled && requireAllOutput
        && CodeGenerator.isValidParamLength(paramLength)) 
      constructDoConsumeFunction(ctx, inputVars, row)
     else 
      parent.doConsume(ctx, inputVars, rowVar)
    
    s"""
       |$ctx.registerComment(s"CONSUME: $parent.simpleString(conf.maxToStringFields)")
       |$evaluated
       |$consumeFunc
     """.stripMargin
  

其中参数outputVars为传入的rangeExc产生的value

  • val inputVarsCandidate =和val inputVars =
    对于outputVars 不为空的情况下,直接copy复制一份outputVars值作为输入的变量
    如果outputVars为空,而row不为空的情况下,则说明传入的是InteralRow类型的变量,需要调用InteralRow对应的方法获取对应的值

  • val rowVar = prepareRowVar(ctx, row, outputVars)
    这部分在RangeExec中不会用到,这里不讲解(因为rangExec这里数据流会走向constructDoConsumeFunction这里)

  • ctx.currentVars = inputVars ctx.INPUT_ROW = null ctx.freshNamePrefix = parent.variablePrefix
    这里是为了对evaluateRequiredVariables方法做铺垫,因为

  • val evaluated = evaluateRequiredVariables(output, inputVars, parent.usedInputs)
    其中这里的outputRange.getOutputAttrs,即StructType(StructField("id", LongType, nullable = false) :: Nil).toAttributes
    inputVarsrange_value_0
    parent.usedInputsAttributeSet(children.flatMap(_.asInstanceOf[QueryPlan[PlanType]].output)),和output一样,也就是Range.getOutputAttrs,即StructType(StructField("id", LongType, nullable = false) :: Nil).toAttributes
    因为inputVars的code为空,所以 evaluated对于该inputVars计算也为空

  • val confEnabled val requireAllOutput
    这里的两个条件都是 TRUE

  • val paramLength = CodeGenerator.calculateParamLength(output) + (if (row != null) 1 else 0)
    计算表达式的长度,对于LONG和DOUBLE类型长度为2,其他的为1,因为range_value_0是LONG类型,所以总的长度为3

  • val consumeFunc =confEnabled && requireAllOutput&& CodeGenerator.isValidParamLength(paramLength)
    这里的三个条件都满足,所以数据流向constructDoConsumeFunction方法,如下:

    private def constructDoConsumeFunction(
        ctx: CodegenContext,
        inputVars: Seq[ExprCode],
        row: String): String = 
      val (args, params, inputVarsInFunc) = constructConsumeParameters(ctx, output, inputVars, row)
      val rowVar = prepareRowVar(ctx, row, inputVarsInFunc)
    
      val doConsume = ctx.freshName("doConsume")
      ctx.currentVars = inputVarsInFunc
      ctx.INPUT_ROW = null
    
      val doConsumeFuncName = ctx.addNewFunction(doConsume,
        s"""
           | private void $doConsume($params.mkString(", ")) throws java.io.IOException 
           |   $parent.doConsume(ctx, inputVarsInFunc, rowVar)
           | 
         """.stripMargin)
    
      s"""
         | $doConsumeFuncName($args.mkString(", "));
       """.stripMargin
    
    

    其中inputVarsrange_value_0
    rowNULL

    • val (args, params, inputVarsInFunc) = constructConsumeParameters(ctx, output, inputVars, row)
      构造 函数实参,形参,以及形参ExprCode变量,分别为range_value_0,long sortAgg_expr_0_0,sortAgg_expr_0_0
    • val rowVar = prepareRowVar(ctx, row, inputVarsInFunc)
      这里是构造UnsafeRow类型的变量便于传给parent进行消费 ,其中 row为NULL,inputVarsInFuncsortAgg_expr_0_0
    private def prepareRowVar(ctx: CodegenContext, row: String, colVars: Seq[ExprCode]): ExprCode = 
      if (row != null) 
        ExprCode.forNonNullValue(JavaCode.variable(row, classOf[UnsafeRow]))
       else 
        if (colVars.nonEmpty) 
          val colExprs = output.zipWithIndex.map  case (attr, i) =>
            BoundReference(i, attr.dataType, attr.nullable)
          
          val evaluateInputs = evaluateVariables(colVars)
          // generate the code to create a UnsafeRow
          ctx.INPUT_ROW = row
          ctx.currentVars = colVars
          val ev = GenerateUnsafeProjection.createCode(ctx, colExprs, false)
          val code = code"""
            |$evaluateInputs
            |$ev.code
           """.stripMargin
          ExprCode(code, FalseLiteral, ev.value)
         else 
          // There are no columns
          ExprCode.forNonNullValue(JavaCode.variable("unsafeRow", classOf[UnsafeRow]))
        
      
    
    
    • 对于val colExprs =
      这块是针对当前物理计划的输出(output)与变量值进行绑定,对于RangeExec来说output的值为Range.getOutputAttrs,即StructType(StructField(“id”, LongType, nullable = false) :: Nil).toAttributes ,而当前rangexec的对应的变量为range_value_0

    • val evaluateInputs = evaluateVariables(colVars)
      对于不是直接赋值的变量,而是通过计算得到的变量,则需要进行提前计算,在这里不需要计算。

    • val ev = GenerateUnsafeProjection.createCode(ctx, colExprs, false)
      这部分是产生UnsafeRow类型的变量,这个UnsafeRow类型的变量里包含了rangExec的产生的变量rang_value_0
      里面具体的细节,这里先忽略,以后会有具体的文章分析。

    • ExprCode(code, FalseLiteral, ev.value)
      这里就返回ExprCode类型的数据结构,
      其中code如下:range_mutableStateArray_0[0].reset();range_mutableStateArray_0[0].write(0, sortAgg_expr_0_0);
      ev.value如下:range_mutableStateArray_0[0].getRow()

    • val doConsume = ctx.freshName(“doConsume”)
      构建函数的名字,这里为sortAgg_doConsume_0

    • val doConsumeFuncName =
      构造函数调用,其中主要调用的是parent.doConsume(ctx, inputVarsInFunc, rowVar)方法,
      注意:这里的rowVar在SortAggregateExec中不会被用到,但是在WholeStageCodeGenExec中会被用到

  • 最后的s"""$ctx.registerComment(s"CONSUME: $parent.simpleString(conf.maxToStringFields)")$evaluated 则是组装代码

以上是关于SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起的主要内容,如果未能解决你的问题,请参考以下文章

SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起

SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起

SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起

SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起

SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起

SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起