SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起(3)

Posted 鸿乃江边鸟

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起(3)相关的知识,希望对你有一定的参考价值。

背景

本文基于 SPARK 3.3.0
从一个unit test来探究SPARK Codegen的逻辑,

  test("SortAggregate should be included in WholeStageCodegen") 
    val df = spark.range(10).agg(max(col("id")), avg(col("id")))
    withSQLConf("spark.sql.test.forceApplySortAggregate" -> "true") 
      val plan = df.queryExecution.executedPlan
      assert(plan.exists(p =>
        p.isInstanceOf[WholeStageCodegenExec] &&
          p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SortAggregateExec]))
      assert(df.collect() === Array(Row(9, 4.5)))
    
  
该sql形成的执行计划第一部分的全代码生成部分如下:
WholeStageCodegen

± *(1) SortAggregate(key=[], functions=[partial_max(id#0L), partial_avg(id#0L)], output=[max#12L, sum#13, count#14L])
± *(1) Range (0, 10, step=1, splits=2)


分析

第一阶段wholeStageCodegen

第一阶段的代码生成涉及到SortAggregateExec和RangeExec的produce和consume方法,这里一一来分析:
第一阶段wholeStageCodegen数据流如下:

 WholeStageCodegenExec      SortAggregateExec(partial)     RangeExec        
  =========================================================================
 
  -> execute()
      |
   doExecute() --------->   inputRDDs() -----------------> inputRDDs() 
      |
   doCodeGen()
      |
      +----------------->   produce()
                              |
                           doProduce() 
                              |
                           doProduceWithoutKeys() -------> produce()
                                                              |
                                                          doProduce()
                                                              |
                           doConsume()<------------------- consume()
                              |
                           doConsumeWithoutKeys()
                              |并不是doConsumeWithoutKeys调用consume,而是由doProduceWithoutKeys调用
   doConsume()  <--------  consume()

RangeExec的produce

final def produce(ctx: CodegenContext, parent: CodegenSupport): String = executeQuery 
    this.parent = parent
    ctx.freshNamePrefix = variablePrefix
    s"""
       |$ctx.registerComment(s"PRODUCE: $this.simpleString(conf.maxToStringFields)")
       |$doProduce(ctx)
     """.stripMargin
  
  • this.parent = parent以及ctx.freshNamePrefix = variablePrefix
    设置parent 以便在做consume方法的时候能够获取到父节点的引用,这样才能调用到父节点的consume方法以便代码生成。
    freshNamePrefix的设置是为了在生成对应的方法的时候,区分不同物理计划的方法,这样能防止方法名重复,避免编译代码时出错。
  • ctx.registerComment
    这块是给java代码加上对应的注释,默认情况下是不会加上的,因为默认spark.sql.codegen.commentsFalse
protected override def doProduce(ctx: CodegenContext): String = 
    val numOutput = metricTerm(ctx, "numOutputRows")

    val initTerm = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initRange")
    val nextIndex = ctx.addMutableState(CodeGenerator.JAVA_LONG, "nextIndex")

    val value = ctx.freshName("value")
    val ev = ExprCode.forNonNullValue(JavaCode.variable(value, LongType))
    val BigInt = classOf[java.math.BigInteger].getName

    // Inline mutable state since not many Range operations in a task
    val taskContext = ctx.addMutableState("TaskContext", "taskContext",
      v => s"$v = TaskContext.get();", forceInline = true)
    val inputMetrics = ctx.addMutableState("InputMetrics", "inputMetrics",
      v => s"$v = $taskContext.taskMetrics().inputMetrics();", forceInline = true)

    // In order to periodically update the metrics without inflicting performance penalty, this
    // operator produces elements in batches. After a batch is complete, the metrics are updated
    // and a new batch is started.
    // In the implementation below, the code in the inner loop is producing all the values
    // within a batch, while the code in the outer loop is setting batch parameters and updating
    // the metrics.

    // Once nextIndex == batchEnd, it's time to progress to the next batch.
    val batchEnd = ctx.addMutableState(CodeGenerator.JAVA_LONG, "batchEnd")

    // How many values should still be generated by this range operator.
    val numElementsTodo = ctx.addMutableState(CodeGenerator.JAVA_LONG, "numElementsTodo")

    // How many values should be generated in the next batch.
    val nextBatchTodo = ctx.freshName("nextBatchTodo")

    // The default size of a batch, which must be positive integer
    val batchSize = 1000

    val initRangeFuncName = ctx.addNewFunction("initRange",
      s"""
        | private void initRange(int idx) 
        |   $BigInt index = $BigInt.valueOf(idx);
        |   $BigInt numSlice = $BigInt.valueOf($numSlicesL);
        |   $BigInt numElement = $BigInt.valueOf($numElements.toLongL);
        |   $BigInt step = $BigInt.valueOf($stepL);
        |   $BigInt start = $BigInt.valueOf($startL);
        |   long partitionEnd;
        |
        |   $BigInt st = index.multiply(numElement).divide(numSlice).multiply(step).add(start);
        |   if (st.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) 
        |     $nextIndex = Long.MAX_VALUE;
        |    else if (st.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) 
        |     $nextIndex = Long.MIN_VALUE;
        |    else 
        |     $nextIndex = st.longValue();
        |   
        |   $batchEnd = $nextIndex;
        |
        |   $BigInt end = index.add($BigInt.ONE).multiply(numElement).divide(numSlice)
        |     .multiply(step).add(start);
        |   if (end.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) 
        |     partitionEnd = Long.MAX_VALUE;
        |    else if (end.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) 
        |     partitionEnd = Long.MIN_VALUE;
        |    else 
        |     partitionEnd = end.longValue();
        |   
        |
        |   $BigInt startToEnd = $BigInt.valueOf(partitionEnd).subtract(
        |     $BigInt.valueOf($nextIndex));
        |   $numElementsTodo  = startToEnd.divide(step).longValue();
        |   if ($numElementsTodo < 0) 
        |     $numElementsTodo = 0;
        |    else if (startToEnd.remainder(step).compareTo($BigInt.valueOf(0L)) != 0) 
        |     $numElementsTodo++;
        |   
        | 
       """.stripMargin)

    val localIdx = ctx.freshName("localIdx")
    val localEnd = ctx.freshName("localEnd")
    val stopCheck = if (parent.needStopCheck) 
      s"""
         |if (shouldStop()) 
         |  $nextIndex = $value + $stepL;
         |  $numOutput.add($localIdx + 1);
         |  $inputMetrics.incRecordsRead($localIdx + 1);
         |  return;
         |
       """.stripMargin
     else 
      "// shouldStop check is eliminated"
    
    val loopCondition = if (limitNotReachedChecks.isEmpty) 
      "true"
     else 
      limitNotReachedChecks.mkString(" && ")
    
    s"""
      | // initialize Range
      | if (!$initTerm) 
      |   $initTerm = true;
      |   $initRangeFuncName(partitionIndex);
      | 
      |
      | while ($loopCondition) 
      |   if ($nextIndex == $batchEnd) 
      |     long $nextBatchTodo;
      |     if ($numElementsTodo > $batchSizeL) 
      |       $nextBatchTodo = $batchSizeL;
      |       $numElementsTodo -= $batchSizeL;
      |      else 
      |       $nextBatchTodo = $numElementsTodo;
      |       $numElementsTodo = 0;
      |       if ($nextBatchTodo == 0) break;
      |     
      |     $batchEnd += $nextBatchTodo * $stepL;
      |   
      |
      |   int $localEnd = (int)(($batchEnd - $nextIndex) / $stepL);
      |   for (int $localIdx = 0; $localIdx < $localEnd; $localIdx++) 
      |     long $value = ((long)$localIdx * $stepL) + $nextIndex;
      |     $consume(ctx, Seq(ev))
      |     $stopCheck
      |   
      |   $nextIndex = $batchEnd;
      |   $numOutput.add($localEnd);
      |   $inputMetrics.incRecordsRead($localEnd);
      |   $taskContext.killTaskIfInterrupted();
      | 
     """.stripMargin
  
  • val numOutput = metricTerm(ctx, “numOutputRows”)
    numOutput指标,用于记录输出的记录条数

  • val initTerm =以及val nextIndex =
    initTerm用于标识该物理计划是够已经生成了代码,
    nextIndex是用来产生rangeExec数据的逻辑索引,遍历数据
    这两个参数也是类的成员变量,即全局变量

  • val value =和val ev =
    这个ev值是用来表示rangExec生成的数据的,最终会被*consume(ctx, Seq(ev)方法所调用
    而其中的value变量则是会在
    long v a l u e = ( ( l o n g ) value = ((long) value=((long)localIdx * $stepL) + $nextIndex;*被赋值,这样父节点才能进行消费

  • val taskContext =和val inputMetrics =
    taskContextinputMetrics也是全部变量,而且还有初始化变量,这种初始化方法将会在生成的类方法init中进行初始化,会形成一下代码:

    range_taskContext_0 = TaskContext.get();
    range_inputMetrics_0 = range_taskContext_0.taskMetrics().inputMetrics();
    

    之所以会在init方法进行初始化是因为该初始化方法会被放入到mutableStateInitCodeArray类型的变量中,而mutableStateInitCode里的
    数据,将会在WholeStageCodegenExec的*ctx.initMutableStates()*会被组装调用,被调用的代码如下:

     
     public void init(int index, scala.collection.Iterator[] inputs) 
           partitionIndex = index;
           this.inputs = inputs;
           $ctx.initMutableStates()
           $ctx.initPartition()
         
    
  • val batchEnd =和val numElementsTodo
    这两个变量也是生成类的成员变量,即全局变量

  • val nextBatchTodo =
    这个变量是临时变量,会在遍历生成数据的时候用到

  • val initRangeFuncName =
    就是RangeExec生成数据的逻辑了,每个物理计划都是不一样。这里忽略

  • 最后的while ($loopCondition)
    这部分就是根据每个分区的index不一样,生成不同的数据。
    值得一提的是initRangeFuncName(partitionIndex)这部分中的partitionIndex变量,这个变量是生成的类的父类BufferedRowIterator中,
    partitionIndex变量的赋值也在init方法中,具体代码如下:

    public void init(int index, scala.collection.Iterator[] inputs) 
            partitionIndex = index;
            this.inputs = inputs; 
    
  • consume(ctx, Seq(ev))
    父节点进行消费rangeExec产生的数据,接下来会继续讲解

  • numOutput和inputMetrics和taskContext
    numOutput 进行输出数据的增加
    inputMetrics 在taskMetrics级别数据的增加
    taskContext.killTaskIfInterrupted 用来判断当前任务是不是被kill了,如果被kill了直接抛出异常

以上是关于SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起(3)的主要内容,如果未能解决你的问题,请参考以下文章

SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起

SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起

SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起

SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起

SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起

SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起