WholeStageCodegen in Spark: whole-stage code generation, starting from aggregate code generation

Posted by 鸿乃江边鸟


Background

This article is based on Spark 3.3.0.
We use a unit test as the starting point for exploring Spark's codegen logic:

  test("SortAggregate should be included in WholeStageCodegen") {
    val df = spark.range(10).agg(max(col("id")), avg(col("id")))
    withSQLConf("spark.sql.test.forceApplySortAggregate" -> "true") {
      val plan = df.queryExecution.executedPlan
      assert(plan.exists(p =>
        p.isInstanceOf[WholeStageCodegenExec] &&
          p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SortAggregateExec]))
      assert(df.collect() === Array(Row(9, 4.5)))
    }
  }

The whole-stage codegen part of the first stage of the physical plan generated for this query is:

   WholeStageCodegen     
   +- *(1) SortAggregate(key=[], functions=[partial_max(id#0L), partial_avg(id#0L)], output=[max#12L, sum#13, count#14L])
      +- *(1) Range (0, 10, step=1, splits=2)    

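Analysis

Stage 1 wholeStageCodegen

Code generation for the first stage involves the produce and consume methods of SortAggregateExec and RangeExec. We go through them one by one.
The data flow of the first-stage wholeStageCodegen is: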

 WholeStageCodegenExec      SortAggregateExec(partial)     RangeExec        
  =========================================================================
 
  -> execute()
      |
   doExecute() --------->   inputRDDs() -----------------> inputRDDs() 
      |
   doCodeGen()
      |
      +----------------->   produce()
                              |
                           doProduce() 
                              |
                           doProduceWithoutKeys() -------> produce()
                                                              |
                                                          doProduce()
                                                              |
                           doConsume()<------------------- consume()
                              |
                           doConsumeWithoutKeys()
                              |  (note: consume is called from doProduceWithoutKeys, not from doConsumeWithoutKeys)
   doConsume()  <--------  consume()
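
The call pattern in the diagram can be illustrated with a small, self-contained toy model. This is only a sketch: ToyOperator, ToyRange, ToyMaxAggregate and ToyWholeStage are hypothetical names invented here, not Spark classes, and the "generated code" is just a string. It shows the same shape as above: produce() is called top-down, the leaf calls back into its parent's doConsume(), and the blocking aggregate calls its own parent's doConsume() from inside produce(), after the child's loop.

```java
// Toy model of the produce/consume protocol. None of these types are Spark classes.
interface ToyOperator {
    // Ask this operator to emit the code that produces its rows for `parent`.
    String produce(ToyOperator parent);
    // Called by the child with the name of the variable holding one produced value.
    String doConsume(String inputVar);
}

// Leaf operator: emits a loop and asks its parent what to do with each value.
final class ToyRange implements ToyOperator {
    private final long end;
    ToyRange(long end) { this.end = end; }

    @Override public String produce(ToyOperator parent) {
        return "for (long v = 0; v < " + end + "L; v++) {\n"
             + "  " + parent.doConsume("v") + "\n"
             + "}\n";
    }

    @Override public String doConsume(String inputVar) {
        throw new UnsupportedOperationException("a leaf has no child to consume from");
    }
}

// Blocking aggregate without keys: the per-row update comes from doConsume(),
// while the result is handed to the parent from produce(), after the child's loop.
final class ToyMaxAggregate implements ToyOperator {
    private final ToyOperator child;
    ToyMaxAggregate(ToyOperator child) { this.child = child; }

    @Override public String produce(ToyOperator parent) {
        String aggLoop = child.produce(this); // child calls back into this.doConsume()
        return "long max = Long.MIN_VALUE;\n" + aggLoop + parent.doConsume("max") + "\n";
    }

    @Override public String doConsume(String inputVar) {
        return "max = Math.max(max, " + inputVar + ");";
    }
}

// Root of the stage: starts produce() and turns each result into an append() call.
final class ToyWholeStage implements ToyOperator {
    private final ToyOperator child;
    ToyWholeStage(ToyOperator child) { this.child = child; }

    @Override public String produce(ToyOperator parent) { return child.produce(this); }
    @Override public String doConsume(String inputVar) { return "append(" + inputVar + ");"; }
}

final class ProduceConsumeDemo {
    static String generate() {
        return new ToyWholeStage(new ToyMaxAggregate(new ToyRange(10))).produce(null);
    }

    public static void main(String[] args) {
        System.out.print(generate());
    }
}
```

Running the demo prints a loop whose body is the aggregate's update statement, followed by a single append(max); line contributed by the root.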

The consume method of SortAggregateExec (Partial)

This method is called from doProduceWithoutKeys. The code is:

 s"""
       |while (!$initAgg) {
       |  $initAgg = true;
       |  $doAggWithRecordMetric
       |
       |  // output the result
       |  ${genResult.trim}
       |
       |  $numOutput.add(1);
       |  ${consume(ctx, resultVars).trim}
       |}
     """.stripMargin

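Here resultVars is flatBufVars, that is, the class-level aggregation buffer variables such as sortAgg_bufValue_1 and sortAgg_bufValue_2.
In part (5) of this series we noted that once the corresponding aggregate function has been evaluated, sortAgg_bufValue_1 and sortAgg_bufValue_2 are assigned the computed results, as follows: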

    sortAgg_bufIsNull_1 = sortAgg_isNull_7;
    sortAgg_bufValue_1 = sortAgg_value_7;
    sortAgg_bufIsNull_2 = sortAgg_isNull_13;
    sortAgg_bufValue_2 = sortAgg_value_13;

So resultVars already holds the computed results.
The consume method itself has already been covered.
What differs here:

  1. The output of SortAggregateExec (Partial) is max, sum, count.
  2. val rowVar = prepareRowVar(ctx, row, outputVars) returns an ExprCode whose value is an UnsafeRow containing max, sum, count, as follows:
ExprCode(range_mutableStateArray_0[2].reset();
range_mutableStateArray_0[2].zeroOutNullBytes();
if (sortAgg_bufIsNull_0) {
  range_mutableStateArray_0[2].setNullAt(0);
} else {
  range_mutableStateArray_0[2].write(0, sortAgg_bufValue_0);
}
if (sortAgg_bufIsNull_1) {
  range_mutableStateArray_0[2].setNullAt(1);
} else {
  range_mutableStateArray_0[2].write(1, sortAgg_bufValue_1);
}
if (sortAgg_bufIsNull_2) {
  range_mutableStateArray_0[2].setNullAt(2);
} else {
  range_mutableStateArray_0[2].write(2, sortAgg_bufValue_2);
}
,false,(range_mutableStateArray_0[2].getRow()))
  3. val requireAllOutput = output.forall(parent.usedInputs.contains(_)) evaluates to false,
    so control goes straight to parent.doConsume(ctx, inputVars, rowVar).
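
To make the row-writing code above easier to read, here is a deliberately simplified toy writer. It is an assumption-laden sketch, not Spark's UnsafeRowWriter: ToyRowWriter and ToyRowWriterDemo are hypothetical names, the row is a plain long[] with one word of null bits followed by one 8-byte slot per field, and variable-length data is ignored. It only mirrors the call sequence seen in the generated code: clear the null bits, then for each field either set its null bit or write its value.

```java
// Toy stand-in for a fixed-width row writer. NOT Spark's UnsafeRowWriter.
final class ToyRowWriter {
    // words[0] holds the null bits; words[1 + i] holds field i as 8 bytes.
    private final long[] words;

    ToyRowWriter(int numFields) {
        if (numFields < 0 || numFields > 64) {
            throw new IllegalArgumentException("toy model supports 0..64 fields");
        }
        words = new long[1 + numFields];
    }

    // Mark every field as non-null before a new row is written.
    void zeroOutNullBytes() { words[0] = 0L; }

    void setNullAt(int ordinal) {
        words[0] |= (1L << ordinal);
        words[1 + ordinal] = 0L;
    }

    void write(int ordinal, long value) { words[1 + ordinal] = value; }

    // Doubles are stored as their raw 64-bit pattern in the same kind of slot.
    void write(int ordinal, double value) { words[1 + ordinal] = Double.doubleToLongBits(value); }

    boolean isNullAt(int ordinal) { return (words[0] & (1L << ordinal)) != 0L; }
    long getLong(int ordinal) { return words[1 + ordinal]; }
    double getDouble(int ordinal) { return Double.longBitsToDouble(words[1 + ordinal]); }
}

final class ToyRowWriterDemo {
    // Same shape as the generated code: (max, sum, count) into a 3-field row.
    static ToyRowWriter writePartialRow(boolean maxIsNull, long max, double sum, long count) {
        ToyRowWriter w = new ToyRowWriter(3);
        w.zeroOutNullBytes();
        if (maxIsNull) {
            w.setNullAt(0);
        } else {
            w.write(0, max);
        }
        w.write(1, sum);
        w.write(2, count);
        return w;
    }

    public static void main(String[] args) {
        ToyRowWriter w = writePartialRow(false, 4L, 10.0, 5L);
        System.out.println(w.getLong(0) + ", " + w.getDouble(1) + ", " + w.getLong(2));
    }
}
```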

The doConsume method of WholeStageCodegenExec

 override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
    val doCopy = if (needCopyResult) {
      ".copy()"
    } else {
      ""
    }
    s"""
      |${row.code}
      |append(${row.value}$doCopy);
     """.stripMargin.trim
  }

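Here input is Seq(max, sum, count), and row is the ExprCode for the UnsafeRow containing max, sum, count.

  • val doCopy =
    needCopyResult returns children.head.asInstanceOf[CodegenSupport].needCopyResult, which here is the needCopyResult of SortAggregateExec. That is false, so doCopy is the empty string.

  • ${row.code}
    The assembled code is simply: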

     range_mutableStateArray_0[2].reset();
     range_mutableStateArray_0[2].zeroOutNullBytes();
     if (sortAgg_bufIsNull_0) {
       range_mutableStateArray_0[2].setNullAt(0);
     } else {
       range_mutableStateArray_0[2].write(0, sortAgg_bufValue_0);
     }
     if (sortAgg_bufIsNull_1) {
       range_mutableStateArray_0[2].setNullAt(1);
     } else {
       range_mutableStateArray_0[2].write(1, sortAgg_bufValue_1);
     }
     if (sortAgg_bufIsNull_2) {
       range_mutableStateArray_0[2].setNullAt(2);
     } else {
       range_mutableStateArray_0[2].write(2, sortAgg_bufValue_2);
     }
     append((range_mutableStateArray_0[2].getRow()));
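
As a general illustration of why a copy flag like doCopy can matter at all, the sketch below buffers rows that all come from one reused mutable object. The names ReusedRowDemo and MutableRow are hypothetical and this is not Spark code. In the stage discussed here the aggregate appends a single row per partition, which is consistent with no copy being needed; the hazard only appears when several rows are appended from the same buffer before anyone reads them.

```java
import java.util.ArrayList;
import java.util.List;

final class ReusedRowDemo {
    // Hypothetical mutable row that a producer overwrites for every record.
    static final class MutableRow {
        long value;

        MutableRow copy() {
            MutableRow r = new MutableRow();
            r.value = value;
            return r;
        }
    }

    // Buffers three rows produced from one reused object, with or without copying.
    static List<Long> collect(boolean doCopy) {
        MutableRow reused = new MutableRow();
        List<MutableRow> buffered = new ArrayList<>();
        for (long v = 1; v <= 3; v++) {
            reused.value = v;                               // overwrite in place
            buffered.add(doCopy ? reused.copy() : reused);  // like append(row.copy()) vs append(row)
        }
        List<Long> out = new ArrayList<>();
        for (MutableRow r : buffered) {
            out.add(r.value);
        }
        return out;
    }

    public static void main(String[] args) {
        System.out.println(collect(false)); // every entry aliases the same object
        System.out.println(collect(true));  // every entry is an independent snapshot
    }
}
```

Without the copy, all three buffered entries show the last value written (3, 3, 3); with the copy they keep 1, 2, 3.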
    

The doCodeGen method of WholeStageCodegenExec

The code is as follows:

  def doCodeGen(): (CodegenContext, CodeAndComment) = {
    val startTime = System.nanoTime()
    val ctx = new CodegenContext
    val code = child.asInstanceOf[CodegenSupport].produce(ctx, this)

    // main next function.
    ctx.addNewFunction("processNext",
      s"""
        protected void processNext() throws java.io.IOException {
          ${code.trim}
        }
       """, inlineToOuterClass = true)

    val className = generatedClassName()

    val source = s"""
      public Object generate(Object[] references) {
        return new $className(references);
      }

      ${ctx.registerComment(
        s"""Codegened pipeline for stage (id=$codegenStageId)
           |${this.treeString.trim}""".stripMargin,
         "wsc_codegenPipeline")}
      ${ctx.registerComment(s"codegenStageId=$codegenStageId", "wsc_codegenStageId", true)}
      final class $className extends ${classOf[BufferedRowIterator].getName} {

        private Object[] references;
        private scala.collection.Iterator[] inputs;
        ${ctx.declareMutableStates()}

        public $className(Object[] references) {
          this.references = references;
        }

        public void init(int index, scala.collection.Iterator[] inputs) {
          partitionIndex = index;
          this.inputs = inputs;
          ${ctx.initMutableStates()}
          ${ctx.initPartition()}
        }

        ${ctx.emitExtraCode()}

        ${ctx.declareAddedFunctions()}
      }
      """.trim

    // try to compile, helpful for debug
    val cleanedSource = CodeFormatter.stripOverlappingComments(
      new CodeAndComment(CodeFormatter.stripExtraNewLines(source), ctx.getPlaceHolderToComments()))

    val duration = System.nanoTime() - startTime
    WholeStageCodegenExec.increaseCodeGenTime(duration)

    logDebug(s"\n${CodeFormatter.format(cleanedSource)}")
    (ctx, cleanedSource)
  }

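  • val code = child.asInstanceOf[CodegenSupport].produce(ctx, this)
    code is the logic generated by the operators of this stage.
  • ctx.addNewFunction
    The generated code is wrapped in the processNext method.
  • val className = generatedClassName()
    The name of the generated class.
  • val source =
    Calls such as ctx.declareMutableStates() and ctx.initMutableStates() emit the declarations and initialization of the variables that were referenced during code generation.
  • (ctx, cleanedSource)
    Returns the generated code.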
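
The template above produces a class that extends BufferedRowIterator and places all operator logic in processNext. The sketch below is a rough, self-contained imitation of how such a class might be driven. ToyBufferedIterator and ToyGeneratedIterator are hypothetical stand-ins written for this example (rows are plain strings, and the input is hard-coded as 0 until 10), so treat it as an illustration of the control flow rather than the real base class.

```java
import java.util.LinkedList;

// Hypothetical, simplified stand-in for the base class of the generated iterator.
abstract class ToyBufferedIterator {
    private final LinkedList<String> currentRows = new LinkedList<>();

    // Generated code calls this to hand a finished row to the consumer.
    protected void append(String row) { currentRows.add(row); }

    // Generated code implements this; each call may append zero or more rows.
    protected abstract void processNext();

    public boolean hasNext() {
        if (currentRows.isEmpty()) {
            processNext();
        }
        return !currentRows.isEmpty();
    }

    public String next() { return currentRows.remove(); }
}

// Plays the role of the generated class for a keyless partial aggregate.
final class ToyGeneratedIterator extends ToyBufferedIterator {
    private boolean initAgg; // same purpose as sortAgg_initAgg_0

    @Override protected void processNext() {
        // The flag makes the whole aggregation run exactly once, however
        // often hasNext() triggers processNext().
        while (!initAgg) {
            initAgg = true;
            long max = Long.MIN_VALUE;
            double sum = 0.0;
            long count = 0L;
            for (long v = 0; v < 10; v++) { // hard-coded input for the sketch
                max = Math.max(max, v);
                sum += (double) v;
                count += 1L;
            }
            append(max + "," + sum + "," + count);
        }
    }
}

final class ToyIteratorDemo {
    public static void main(String[] args) {
        ToyGeneratedIterator it = new ToyGeneratedIterator();
        while (it.hasNext()) {
            System.out.println(it.next());
        }
    }
}
```

The second hasNext() call still invokes processNext(), but the initAgg flag keeps the loop body from running again, so exactly one row is emitted.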

Final generated code for the stage 1 wholeStageCodegen

public Object generate(Object[] references) {
  return new GeneratedIteratorForCodegenStage1(references);
}

// codegenStageId=1
final class GeneratedIteratorForCodegenStage1 extends org.apache.spark.sql.execution.BufferedRowIterator {
  private Object[] references;
  private scala.collection.Iterator[] inputs;
  private boolean sortAgg_initAgg_0;
  private boolean sortAgg_bufIsNull_0;
  private long sortAgg_bufValue_0;
  private boolean sortAgg_bufIsNull_1;
  private double sortAgg_bufValue_1;
  private boolean sortAgg_bufIsNull_2;
  private long sortAgg_bufValue_2;
  private boolean range_initRange_0;
  private long range_nextIndex_0;
  private TaskContext range_taskContext_0;
  private InputMetrics range_inputMetrics_0;
  private long range_batchEnd_0;
  private long range_numElementsTodo_0;
  private boolean sortAgg_sortAgg_isNull_4_0;
  private boolean sortAgg_sortAgg_isNull_9_0;
  private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[] range_mutableStateArray_0 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[3];

  public GeneratedIteratorForCodegenStage1(Object[] references) {
    this.references = references;
  }

  public void init(int index, scala.collection.Iterator[] inputs) {
    partitionIndex = index;
    this.inputs = inputs;

    range_taskContext_0 = TaskContext.get();
    range_inputMetrics_0 = range_taskContext_0.taskMetrics().inputMetrics();
    range_mutableStateArray_0[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
    range_mutableStateArray_0[1] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
    range_mutableStateArray_0[2] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(3, 0);

  }

  private void sortAgg_doAggregate_max_0(long sortAgg_expr_0_0) throws java.io.IOException {
    sortAgg_sortAgg_isNull_4_0 = true;
    long sortAgg_value_4 = -1L;

    if (!sortAgg_bufIsNull_0 && (sortAgg_sortAgg_isNull_4_0 ||
        sortAgg_bufValue_0 > sortAgg_value_4)) {
      sortAgg_sortAgg_isNull_4_0 = false;
      sortAgg_value_4 = sortAgg_bufValue_0;
    }

    if (!false && (sortAgg_sortAgg_isNull_4_0 ||
        sortAgg_expr_0_0 > sortAgg_value_4)) {
      sortAgg_sortAgg_isNull_4_0 = false;
      sortAgg_value_4 = sortAgg_expr_0_0;
    }

    sortAgg_bufIsNull_0 = sortAgg_sortAgg_isNull_4_0;
    sortAgg_bufValue_0 = sortAgg_value_4;
  }

  private void sortAgg_doAggregateWithoutKey_0() throws java.io.IOException {
    // initialize aggregation buffer
    sortAgg_bufIsNull_0 = true;
    sortAgg_bufValue_0 = -1L;
    sortAgg_bufIsNull_1 = false;
    sortAgg_bufValue_1 = 0.0D;
    sortAgg_bufIsNull_2 = false;
    sortAgg_bufValue_2 = 0L;

    // initialize Range
    if (!range_initRange_0) {
      range_initRange_0 = true;
      initRange(partitionIndex);
    }

    while (true) {
      if (range_nextIndex_0 == range_batchEnd_0) {
        long range_nextBatchTodo_0;
        if (range_numElementsTodo_0 > 1000L) {
          range_nextBatchTodo_0 = 1000L;
          range_numElementsTodo_0 -= 1000L;
        } else {
          range_nextBatchTodo_0 = range_numElementsTodo_0;
          range_numElementsTodo_0 = 0;
          if (range_nextBatchTodo_0 == 0) break;
        }
        range_batchEnd_0 += range_nextBatchTodo_0 * 1L;
      }

      int range_localEnd_0 = (int)((range_batchEnd_0 - range_nextIndex_0) / 1L);
      for (int range_localIdx_0 = 0; range_localIdx_0 < range_localEnd_0; range_localIdx_0++) {
        long range_value_0 = ((long)range_localIdx_0 * 1L) + range_nextIndex_0;

        sortAgg_doConsume_0(range_value_0);

        // shouldStop check is eliminated
      }
      range_nextIndex_0 = range_batchEnd_0;
      ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(range_localEnd_0);
      range_inputMetrics_0.incRecordsRead(range_localEnd_0);
      range_taskContext_0.killTaskIfInterrupted();
    }

  }

  private void initRange(int idx) {
    java.math.BigInteger index = java.math.BigInteger.valueOf(idx);
    java.math.BigInteger numSlice = java.math.BigInteger.valueOf(2L);
    java.math.BigInteger numElement = java.math.BigInteger.valueOf(10L);
    java.math.BigInteger step = java.math.BigInteger.valueOf(1L);
    java.math.BigInteger start = java.math.BigInteger.valueOf(0L);
    long partitionEnd;

    java.math.BigInteger st = index.multiply(numElement).divide(numSlice).multiply(step).add(start);
    if (st.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) {
      range_nextIndex_0 = Long.MAX_VALUE;
    } else if (st.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) {
      range_nextIndex_0 = Long.MIN_VALUE;
    } else {
      range_nextIndex_0 = st.longValue();
    }
    range_batchEnd_0 = range_nextIndex_0;

    java.math.BigInteger end = index.add(java.math.BigInteger.ONE).multiply(numElement).divide(numSlice)
    .multiply(step).add(start);
    if (end.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) {
      partitionEnd = Long.MAX_VALUE;
    } else if (end.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) {
      partitionEnd = Long.MIN_VALUE;
    } else {
      partitionEnd = end.longValue();
    }

    java.math.BigInteger startToEnd = java.math.BigInteger.valueOf(partitionEnd).subtract(
      java.math.BigInteger.valueOf(range_nextIndex_0));
    range_numElementsTodo_0  = startToEnd.divide(step).longValue();
    if (range_numElementsTodo_0 < 0) {
      range_numElementsTodo_0 = 0;
    } else if (startToEnd.remainder(step).compareTo(java.math.BigInteger.valueOf(0L)) != 0) {
      range_numElementsTodo_0++;
    }
  }

  protected void processNext() throws java.io.IOException {
    while (!sortAgg_initAgg_0) {
      sortAgg_initAgg_0 = true;
      sortAgg_doAggregateWithoutKey_0();

      // output the result

      ((org.apache.spark.sql.execution.metric.SQLMetric) references[1] /* numOutputRows */).add(1);
      range_mutableStateArray_0[2].reset();

      range_mutableStateArray_0[2].zeroOutNullBytes();

      if (sortAgg_bufIsNull_0) {
        range_mutableStateArray_0[2].setNullAt(0);
      } else {
        range_mutableStateArray_0[2].write(0, sortAgg_bufValue_0);
      }

      if (sortAgg_bufIsNull_1) {
        range_mutableStateArray_0[2].setNullAt(1);
      } else {
        range_mutableStateArray_0[2].write(1, sortAgg_bufValue_1);
      }

      if (sortAgg_bufIsNull_2) {
        range_mutableStateArray_0[2].setNullAt(2);
      } else {
        range_mutableStateArray_0[2].write(2, sortAgg_bufValue_2);
      }
      append((range_mutableStateArray_0[2].getRow()));
    }
  }

  private void sortAgg_doConsume_0(long sortAgg_expr_0_0) throws java.io.IOException {
    // do aggregate
    // common sub-expressions

    // evaluate aggregate functions and update aggregation buffers
    sortAgg_doAggregate_max_0(sortAgg_expr_0_0);
    sortAgg_doAggregate_avg_0(sortAgg_expr_0_0);

  }

  private void sortAgg_doAggregate_avg_0(long sortAgg_expr_0_0) throws java.io.IOException {
    boolean sortAgg_isNull_7 = true;
    double sortAgg_value_7 = -1.0;

    if (!sortAgg_bufIsNull_1) {
      sortAgg_sortAgg_isNull_9_0 = true;
      double sortAgg_value_9 = -1.0;
      do {
        boolean sortAgg_isNull_10 = false;
        double sortAgg_value_10 = -1.0;
        if (!false) {
          sortAgg_value_10 = (double) sortAgg_expr_0_0;
        }
        if (!sortAgg_isNull_10) {
          sortAgg_sortAgg_isNull_9_0 = false;
          sortAgg_value_9 = sortAgg_value_10;
          continue;
        }

        if (!false) {
          sortAgg_sortAgg_isNull_9_0 = false;
          sortAgg_value_9 = 0.0D;
          continue;
        }

      } while (false);

      sortAgg_isNull_7 = false; // resultCode could change nullability.

      sortAgg_value_7 = sortAgg_bufValue_1 + sortAgg_value_9;

    }
    boolean sortAgg_isNull_13 = false;
    long sortAgg_value_13 = -1L;
    if (!false && false) {
      sortAgg_isNull_13 = sortAgg_bufIsNull_2;
      sortAgg_value_13 = sortAgg_bufValue_2;
    } else {
      boolean sortAgg_isNull_17 = true;
      long sortAgg_value_17 = -1L;

      if (!sortAgg_bufIsNull_2) {
        sortAgg_isNull_17 = false; // resultCode could change nullability.

        sortAgg_value_17 = sortAgg_bufValue_2 + 1L;

      }
      sortAgg_isNull_13 = sortAgg_isNull_17;
      sortAgg_value_13 = sortAgg_value_17;
    }

    sortAgg_bufIsNull_1 = sortAgg_isNull_7;
    sortAgg_bufValue_1 = sortAgg_value_7;

    sortAgg_bufIsNull_2 = sortAgg_isNull_13;
    sortAgg_bufValue_2 = sortAgg_value_13;
  }

}
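
To connect the generated code to the expected result Row(9, 4.5) from the unit test, the following self-contained sketch replays the arithmetic by hand. It reuses the BigInteger formulas from initRange (10 elements, 2 slices, step 1, start 0) but leaves out the overflow clamping and assumes a positive step. The class PartialAggWalkthrough and the merge step are assumptions made for this example; merging the partial buffers belongs to a later stage that is not shown in this article.

```java
import java.math.BigInteger;

final class PartialAggWalkthrough {
    // One partial aggregation buffer: max (nullable), sum and count for avg.
    static final class Partial {
        boolean maxIsNull = true;
        long max = -1L;
        double sum = 0.0;
        long count = 0L;

        void update(long v) {
            if (maxIsNull || v > max) { // same idea as sortAgg_doAggregate_max_0
                maxIsNull = false;
                max = v;
            }
            sum += (double) v;          // same idea as sortAgg_doAggregate_avg_0
            count += 1L;
        }
    }

    // [start, end) of partition idx, following initRange without the Long clamping.
    static long[] bounds(int idx, long numSlices, long numElements, long step, long start) {
        BigInteger i = BigInteger.valueOf(idx);
        BigInteger n = BigInteger.valueOf(numElements);
        BigInteger s = BigInteger.valueOf(numSlices);
        BigInteger st = BigInteger.valueOf(step);
        BigInteger b = BigInteger.valueOf(start);
        long lo = i.multiply(n).divide(s).multiply(st).add(b).longValue();
        long hi = i.add(BigInteger.ONE).multiply(n).divide(s).multiply(st).add(b).longValue();
        return new long[] {lo, hi};
    }

    // Partial aggregate of one partition of spark.range(10) with 2 splits.
    static Partial aggregate(int idx) {
        long[] b = bounds(idx, 2, 10, 1, 0);
        Partial p = new Partial();
        for (long v = b[0]; v < b[1]; v++) {
            p.update(v);
        }
        return p;
    }

    // Assumed final step (not part of this stage): combine the partial buffers.
    static Partial merge(Partial a, Partial b) {
        Partial m = new Partial();
        m.maxIsNull = a.maxIsNull && b.maxIsNull;
        m.max = a.maxIsNull ? b.max : (b.maxIsNull ? a.max : Math.max(a.max, b.max));
        m.sum = a.sum + b.sum;
        m.count = a.count + b.count;
        return m;
    }

    static double finalAvg() {
        Partial m = merge(aggregate(0), aggregate(1));
        return m.sum / m.count;
    }

    public static void main(String[] args) {
        for (int idx = 0; idx < 2; idx++) {
            Partial p = aggregate(idx);
            System.out.println("partition " + idx + ": max=" + p.max + " sum=" + p.sum + " count=" + p.count);
        }
        Partial m = merge(aggregate(0), aggregate(1));
        System.out.println("merged: max=" + m.max + " avg=" + finalAvg());
    }
}
```

Partition 0 covers 0 until 5 and yields (max=4, sum=10.0, count=5); partition 1 covers 5 until 10 and yields (max=9, sum=35.0, count=5). Merged, that gives max 9 and avg 45.0 / 10 = 4.5, matching the assertion in the unit test.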
