SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起

Posted 鸿乃江边鸟

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起相关的知识,希望对你有一定的参考价值。

背景

本文基于 SPARK 3.3.0
从一个unit test来探究SPARK Codegen的逻辑,

  test("SortAggregate should be included in WholeStageCodegen") 
    val df = spark.range(10).agg(max(col("id")), avg(col("id")))
    withSQLConf("spark.sql.test.forceApplySortAggregate" -> "true") 
      val plan = df.queryExecution.executedPlan
      assert(plan.exists(p =>
        p.isInstanceOf[WholeStageCodegenExec] &&
          p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SortAggregateExec]))
      assert(df.collect() === Array(Row(9, 4.5)))
    
  

该sql形成的执行计划第二部分的全代码生成部分如下:

WholeStageCodegen
*(2) SortAggregate(key=[], functions=[max(id#0L), avg(id#0L)], output=[max(id)#5L, avg(id)#6])
   InputAdapter
+- Exchange SinglePartition, ENSURE_REQUIREMENTS, [id=#13]

分析

第二阶段wholeStageCodegen

第二阶段的代码生成涉及到SortAggregateExec和ShuffleExchangeExec以及InputAdapter的produce和consume方法,这里一一来分析:
第二阶段wholeStageCodegen数据流如下:

 WholeStageCodegenExec      SortAggregateExec(Final)      InputAdapter       ShuffleExchangeExec        
  ====================================================================================
 
  -> execute()
      |
   doExecute() --------->   inputRDDs() -----------------> inputRDDs() -------> execute()
      |                                                                            |
   doCodeGen()                                                                  doExecute()     
      |                                                                            |
      +----------------->   produce()                                           ShuffledRowRDD
                              |
                           doProduce() 
                              |
                           doProduceWithoutKeys() -------> produce()
                                                              |
                                                          doProduce()
                                                              |
                           doConsume() <------------------- consume()
                              |
                           doConsumeWithoutKeys()
                              |并不是doConsumeWithoutKeys调用consume,而是由doProduceWithoutKeys调用
   doConsume()  <--------  consume()

SortAggregateExec(Final) 的doProduce

这里只列出和SortAggregateExec(Partial)的不同的部分:

    val (resultVars, genResult) = if (modes.contains(Final) || modes.contains(Complete)) 
      // evaluate aggregate results
      ctx.currentVars = flatBufVars
      val aggResults = bindReferences(
        functions.map(_.evaluateExpression),
        aggregateBufferAttributes).map(_.genCode(ctx))
      val evaluateAggResults = evaluateVariables(aggResults)
      // evaluate result expressions
      ctx.currentVars = aggResults
      val resultVars = bindReferences(resultExpressions, aggregateAttributes).map(_.genCode(ctx))
      (resultVars,
        s"""
           |$evaluateAggResults
           |$evaluateVariables(resultVars)
         """.stripMargin)
  • 因为我们这里是Final部分,所以我们的数据流和Partial是不同的
  • ctx.currentVars = flatBufVars
    赋值currentVars为当前buffer变量,便于下面进行数据绑定,该buffer变量是全局变量
  • val aggResults = bindReferences
    1. functions.map(_.evaluateExpression) 这是对最终输出结果的计算,对于SUM来说是Divide(sum.cast(resultType), count.cast(resultType), failOnError = false) ,生成的代码如下:
       boolean sortAgg_isNull_6 = sortAgg_bufIsNull_2;
       double sortAgg_value_6 = -1.0;
       if (!sortAgg_bufIsNull_2) 
         sortAgg_value_6 = (double) sortAgg_bufValue_2;
       
       boolean sortAgg_isNull_4 = false;
       double sortAgg_value_4 = -1.0;
       if (sortAgg_isNull_6 || sortAgg_value_6 == 0) 
         sortAgg_isNull_4 = true;
        else 
         if (sortAgg_bufIsNull_1) 
           sortAgg_isNull_4 = true;
          else 
           sortAgg_value_4 = (double)(sortAgg_bufValue_1 / sortAgg_value_6);
         
       
    
    1. aggregateBufferAttributes 聚合函数的buffer属性值 sum :: count :: Nil
      这样在绑定数据的变量数据的时候和currentVars是一一对应的
  • val evaluateAggResults = evaluateVariables(aggResults)
    对聚合的结果进行最终的计算
  • ctx.currentVars = aggResults
    把最终结果的变量赋值给currentVars,便于后面的数据绑定
  • val resultVars = bindReferences(resultExpressions, aggregateAttributes).map(_.genCode(ctx))
    这一步是把聚合结果的变量绑定到聚合表达式中,
    其中resultExpressionsList( avg(id#0L)#3 AS avg(id)#6) (这里我们只考虑AVG)
    aggregateAttributesresultExpressionAttributeReference的一种表达,便于在BoundReference的时候进行映射绑定
    对应的ExprCode为ExprCode(,sortAgg_isNull_4,sortAgg_value_4))

InputAdaptor的 doProduce

InputAdaptor的主要作用是承上启下,用来适配不支持Codegen的物理计划,sql如下:

  override def doProduce(ctx: CodegenContext): String = 
 // Inline mutable state since an InputRDDCodegen is used once in a task for WholeStageCodegen
 val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];",
   forceInline = true)
 val row = ctx.freshName("row")

 val outputVars = if (createUnsafeProjection) 
   // creating the vars will make the parent consume add an unsafe projection.
   ctx.INPUT_ROW = row
   ctx.currentVars = null
   output.zipWithIndex.map  case (a, i) =>
     BoundReference(i, a.dataType, a.nullable).genCode(ctx)
   
  else 
   null
 

 val updateNumOutputRowsMetrics = if (metrics.contains("numOutputRows")) 
   val numOutputRows = metricTerm(ctx, "numOutputRows")
   s"$numOutputRows.add(1);"
  else 
   ""
 
 s"""
    | while ($limitNotReachedCond $input.hasNext()) 
    |   InternalRow $row = (InternalRow) $input.next();
    |   $updateNumOutputRowsMetrics
    |   $consume(ctx, outputVars, if (createUnsafeProjection) null else row).trim
    |   $shouldStopCheckCode
    | 
  """.stripMargin

  • val input = ctx.addMutableState(“scala.collection.Iterator”, “input”, v => s"$v = inputs[0];"
    定义一个input变量用来接受sortaggregate(partial)的输出的InteralRow(unsafeRow),对应的初始化方法会在init方法中调用
  • val row = ctx.freshName(“row”)
    定义一个临时变量用来接受input中的unsafe类型的InteralRow,便于进行迭代操作
  • val outputVars = if (createUnsafeProjection)
    对于InputAdaptor来说createUnsafeProjectionfalse, 所以这块返回的是null
  • val updateNumOutputRowsMetrics =
    因为metrics不满足条件,所以这里也是返回空字符串
  • 代码组装
        s"""
       | while ($limitNotReachedCond $input.hasNext()) 
       |   InternalRow $row = (InternalRow) $input.next();
       |   $updateNumOutputRowsMetrics
       |   $consume(ctx, outputVars, if (createUnsafeProjection) null else row).trim
       |   $shouldStopCheckCode
       | 
     """.stripMargin
    
    对输入的每一行数据进行迭代操作, 之后再调用consume方法,
    注意: 这里的consume传入的是row,是InteralRow类型,而不是在RangeExec中的Long类型的变量

InputAdaptor的 consume

我们这里只说明和之前不一样的部分,对应的sql如下:

  final def consume(ctx: CodegenContext, outputVars: Seq[ExprCode], row: String = null): String =

注意这里的参数 outputVarsnull
rowInteralRow类型的变量

  • val inputVarsCandidate =
 val inputVarsCandidate =
   if (outputVars != null) 
     assert(outputVars.length == output.length)
     // outputVars will be used to generate the code for UnsafeRow, so we should copy them
     outputVars.map(_.copy())
    else 
     assert(row != null, "outputVars and row cannot both be null.")
     ctx.currentVars = null
     ctx.INPUT_ROW = row
     output.zipWithIndex.map  case (attr, i) =>
       BoundReference(i, attr.dataType, attr.nullable).genCode(ctx)
     
  

这里的数据流向了 else :

  • ctx.INPUT_ROW = row
    设置当前的INPUT_ROWrow
    BoundReferencedoGenCode方法也是走向了另一个分支:
   assert(ctx.INPUT_ROW != null, "INPUT_ROW and currentVars cannot both be null.")
   val javaType = JavaCode.javaType(dataType)
   val value = CodeGenerator.getValue(ctx.INPUT_ROW, dataType, ordinal.toString)
   if (nullable) 
     ev.copy(code =
       code"""
          |boolean $ev.isNull = $ctx.INPUT_ROW.isNullAt($ordinal);
          |$javaType $ev.value = $ev.isNull ?
          |  $CodeGenerator.defaultValue(dataType) : ($value);
        """.stripMargin)
    else 
     ev.copy(code = code"$javaType $ev.value = $value;", isNull = FalseLiteral)
   
  • 分析
    • val value = CodeGenerator.getValue(ctx.INPUT_ROW, dataType,ordinal.toString)
      根据数据类型的不同,调用UnsafeRow的不同方法

    • if (nullable)
      因为AttributeReference("sum", sumDataType)()AttributeReference("count", LongType)()表达式 nullableTRUE,所以生成的代码为:

      boolean inputadapter_isNull_0 = inputadapter_row_0.isNullAt(0);
      long inputadapter_value_0 = inputadapter_isNull_0 ?
      -1L : (inputadapter_row_0.getLong(0));
      boolean inputadapter_isNull_1 = inputadapter_row_0.isNullAt(1);
      double inputadapter_value_1 = inputadapter_isNull_1 ?
      -1.0 : (inputadapter_row_0.getDouble(1));
      boolean inputadapter_isNull_2 = inputadapter_row_0.isNullAt(2);
      long inputadapter_value_2 = inputadapter_isNull_2 ?
      -1L : (inputadapter_row_0.getLong(2));
      
  • constructDoConsumeFunction方法中inputVarsInFunc
    这里会多一个名为inputadapter_row_0的InternalRow类型的实参

以上是关于SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起的主要内容,如果未能解决你的问题,请参考以下文章

SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起

SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起

SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起

SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起

SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起

SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起