wholeStageCodegen in SPARK (whole-stage code generation) -- using aggregate code generation as an example
Posted by 鸿乃江边鸟
Background
This article is based on SPARK 3.3.0.
We start from a unit test to explore the logic of SPARK Codegen:
test("SortAggregate should be included in WholeStageCodegen")
val df = spark.range(10).agg(max(col("id")), avg(col("id")))
withSQLConf("spark.sql.test.forceApplySortAggregate" -> "true")
val plan = df.queryExecution.executedPlan
assert(plan.exists(p =>
p.isInstanceOf[WholeStageCodegenExec] &&
p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SortAggregateExec]))
assert(df.collect() === Array(Row(9, 4.5)))
The whole-stage-codegen portion of the first stage of the resulting physical plan is as follows:
WholeStageCodegen
+- *(1) SortAggregate(key=[], functions=[partial_max(id#0L), partial_avg(id#0L)], output=[max#12L, sum#13, count#14L])
+- *(1) Range (0, 10, step=1, splits=2)
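To dump this generated code yourself, Spark ships debug helpers that print every WholeStageCodegen subtree. A minimal sketch (assumes an existing SparkSession named spark; without the spark.sql.test.forceApplySortAggregate testing config used in the test above, the planner may pick HashAggregate instead of SortAggregate, but the code-generation path is the same):

import org.apache.spark.sql.execution.debug._
import org.apache.spark.sql.functions.{avg, col, max}

val df = spark.range(10).agg(max(col("id")), avg(col("id")))
df.debugCodegen()                          // prints each WholeStageCodegen subtree and its generated Java source
// equivalently: df.queryExecution.debug.codegen()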
Analysis
Stage-1 wholeStageCodegen
Code generation for the first stage involves the produce and consume methods of SortAggregateExec and RangeExec; we go through them one by one.
The data flow of the stage-1 wholeStageCodegen is as follows:
WholeStageCodegenExec SortAggregateExec(partial) RangeExec
=========================================================================
-> execute()
|
doExecute() ---------> inputRDDs() -----------------> inputRDDs()
|
doCodeGen()
|
+-----------------> produce()
|
doProduce()
|
doProduceWithoutKeys() -------> produce()
|
doProduce()
|
doConsume()<------------------- consume()
|
doConsumeWithoutKeys()
| (consume here is invoked by doProduceWithoutKeys, not by doConsumeWithoutKeys)
doConsume() <-------- consume()
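The arrows above follow the produce/consume protocol of CodegenSupport. As a rough mental model only (names simplified; the real trait also threads a CodegenContext and ExprCode variables through every call):

// Simplified model of the produce/consume hand-off; illustration only, not Spark's trait.
trait NodeSketch {
  private var parent: NodeSketch = _

  // produce(): the parent asks this node to emit the loop that drives the pipeline.
  final def produce(consumer: NodeSketch): String = {
    parent = consumer
    doProduce()
  }
  protected def doProduce(): String

  // consume(): once this node has the variables for one row, it calls back into the
  // parent's doConsume() so the parent appends its own per-row code.
  final def consume(rowVars: Seq[String]): String = parent.doConsume(rowVars)
  def doConsume(rowVars: Seq[String]): String = ""
}

In the stage above, RangeExec's doProduce emits the row loop and calls consume, which lands in SortAggregateExec's doConsume; SortAggregateExec in turn calls consume from doProduceWithoutKeys to hand the finished buffer to WholeStageCodegenExec.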
The consume method of SortAggregateExec(Partial)
This method is invoked from the doProduceWithoutKeys method; the code is as follows:
s"""
|while (!$initAgg)
| $initAgg = true;
| $doAggWithRecordMetric
|
| // output the result
| $genResult.trim
|
| $numOutput.add(1);
| $consume(ctx, resultVars).trim
|
""".stripMargin
Here resultVars is flatBufVars, i.e. the global variables sortAgg_bufValue_1 and sortAgg_bufValue_2.
In part (5) of this series (wholeStageCodegen in SPARK -- using aggregate code generation as an example) we noted that after the corresponding functions finish their computation, sortAgg_bufValue_1 and sortAgg_bufValue_2 are assigned the computed results, as follows:
sortAgg_bufIsNull_1 = sortAgg_isNull_7;
sortAgg_bufValue_1 = sortAgg_value_7;
sortAgg_bufIsNull_2 = sortAgg_isNull_13;
sortAgg_bufValue_2 = sortAgg_value_13;
So resultVars already holds the computed results.
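To make concrete what those buffer slots hold, here is a plain-Scala illustration of the partial max/avg buffer (illustration only, not Spark code; in the generated class the slots are the sortAgg_bufValue_0/1/2 fields):

// Conceptual model of the partial aggregation buffer: max, sum, count.
// The final aggregation stage later derives avg = sum / count.
final case class PartialAggBuffer(var max: Option[Long], var sum: Double, var count: Long) {
  def update(id: Long): Unit = {
    max = Some(max.fold(id)(m => math.max(m, id)))
    sum += id.toDouble
    count += 1L
  }
}

val buf = PartialAggBuffer(None, 0.0, 0L)
(0L until 10L).foreach(buf.update)                          // same input as Range (0, 10)
assert(buf.max.contains(9L) && buf.sum / buf.count == 4.5)  // matches Row(9, 4.5) in the test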
The consume method itself has already been covered; what differs here is:
- The output of SortAggregateExec(Partial) is max, sum, count, so val rowVar = prepareRowVar(ctx, row, outputVars) returns an UnsafeRow containing max, sum and count, as follows:
ExprCode(range_mutableStateArray_0[2].reset();
range_mutableStateArray_0[2].zeroOutNullBytes();
if (sortAgg_bufIsNull_0) {
  range_mutableStateArray_0[2].setNullAt(0);
} else {
  range_mutableStateArray_0[2].write(0, sortAgg_bufValue_0);
}
if (sortAgg_bufIsNull_1) {
  range_mutableStateArray_0[2].setNullAt(1);
} else {
  range_mutableStateArray_0[2].write(1, sortAgg_bufValue_1);
}
if (sortAgg_bufIsNull_2) {
  range_mutableStateArray_0[2].setNullAt(2);
} else {
  range_mutableStateArray_0[2].write(2, sortAgg_bufValue_2);
}
,false,(range_mutableStateArray_0[2].getRow()))
- val requireAllOutput = output.forall(parent.usedInputs.contains(_)) returns false, so the data flow goes directly to parent.doConsume(ctx, inputVars, rowVar).
The doConsume of WholeStageCodegenExec
override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
  val doCopy = if (needCopyResult) {
    ".copy()"
  } else {
    ""
  }
  s"""
     |${row.code}
     |append(${row.value}$doCopy);
   """.stripMargin.trim
}
Here input is Seq(max, sum, count), and row is the UnsafeRow containing max, sum and count.
- val doCopy is the empty string, because needCopyResult returns children.head.asInstanceOf[CodegenSupport].needCopyResult, and SortAggregateExec's needCopyResult is false.
- row.code is assembled directly into the following code:
range_mutableStateArray_0[2].reset();
range_mutableStateArray_0[2].zeroOutNullBytes();
if (sortAgg_bufIsNull_0) {
  range_mutableStateArray_0[2].setNullAt(0);
} else {
  range_mutableStateArray_0[2].write(0, sortAgg_bufValue_0);
}
if (sortAgg_bufIsNull_1) {
  range_mutableStateArray_0[2].setNullAt(1);
} else {
  range_mutableStateArray_0[2].write(1, sortAgg_bufValue_1);
}
if (sortAgg_bufIsNull_2) {
  range_mutableStateArray_0[2].setNullAt(2);
} else {
  range_mutableStateArray_0[2].write(2, sortAgg_bufValue_2);
}
append((range_mutableStateArray_0[2].getRow()));
The doCodeGen of WholeStageCodegenExec
The concrete code is as follows:
def doCodeGen(): (CodegenContext, CodeAndComment) = {
  val startTime = System.nanoTime()
  val ctx = new CodegenContext
  val code = child.asInstanceOf[CodegenSupport].produce(ctx, this)

  // main next function.
  ctx.addNewFunction("processNext",
    s"""
      protected void processNext() throws java.io.IOException {
        ${code.trim}
      }
     """, inlineToOuterClass = true)

  val className = generatedClassName()

  val source = s"""
    public Object generate(Object[] references) {
      return new $className(references);
    }

    ${ctx.registerComment(
      s"""Codegened pipeline for stage (id=$codegenStageId)
         |${this.treeString.trim}""".stripMargin,
       "wsc_codegenPipeline")}
    ${ctx.registerComment(s"codegenStageId=$codegenStageId", "wsc_codegenStageId", true)}
    final class $className extends ${classOf[BufferedRowIterator].getName} {

      private Object[] references;
      private scala.collection.Iterator[] inputs;
      ${ctx.declareMutableStates()}

      public $className(Object[] references) {
        this.references = references;
      }

      public void init(int index, scala.collection.Iterator[] inputs) {
        partitionIndex = index;
        this.inputs = inputs;
        ${ctx.initMutableStates()}
        ${ctx.initPartition()}
      }

      ${ctx.emitExtraCode()}

      ${ctx.declareAddedFunctions()}
    }
    """.trim

  // try to compile, helpful for debug
  val cleanedSource = CodeFormatter.stripOverlappingComments(
    new CodeAndComment(CodeFormatter.stripExtraNewLines(source), ctx.getPlaceHolderToComments()))

  val duration = System.nanoTime() - startTime
  WholeStageCodegenExec.increaseCodeGenTime(duration)

  logDebug(s"\n${CodeFormatter.format(cleanedSource)}")
  (ctx, cleanedSource)
}
- val code = child.asInstanceOf[CodegenSupport].produce(ctx, this): code is the generated processing logic walked through above.
- ctx.addNewFunction: wraps that code inside the processNext method.
- val className = generatedClassName(): the name of the generated class.
- val source = ...: ctx.declareMutableStates(), ctx.initMutableStates() and the like declare or initialize the variables that were registered during code generation.
- (ctx, cleanedSource): returns the generated code; how this pair is consumed at runtime is sketched below.
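For completeness, a condensed sketch of how WholeStageCodegenExec.doExecute uses the (ctx, cleanedSource) pair per partition (hedged: the compilation-failure fallback, metrics and the two-input case are omitted; runGenerated is a name introduced here purely for illustration):

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeAndComment, CodeGenerator}
import org.apache.spark.sql.execution.BufferedRowIterator

// Condensed model of the per-partition path in WholeStageCodegenExec.doExecute.
def runGenerated(
    cleanedSource: CodeAndComment,
    references: Array[Any],
    partitionIndex: Int,
    input: Iterator[InternalRow]): Iterator[InternalRow] = {
  val (clazz, _) = CodeGenerator.compile(cleanedSource)                  // Janino-compiles the source
  val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
  buffer.init(partitionIndex, Array(input))                              // wires up the child iterator(s)
  new Iterator[InternalRow] {
    override def hasNext: Boolean = buffer.hasNext                       // hasNext() drives processNext()
    override def next(): InternalRow = buffer.next()
  }
}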
Final code of the stage-1 wholeStageCodegen
/* 001 */ public Object generate(Object[] references) {
/* 002 */   return new GeneratedIteratorForCodegenStage1(references);
/* 003 */ }
/* 004 */
/* 005 */ // codegenStageId=1
/* 006 */ final class GeneratedIteratorForCodegenStage1 extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 007 */   private Object[] references;
/* 008 */   private scala.collection.Iterator[] inputs;
/* 009 */   private boolean sortAgg_initAgg_0;
/* 010 */   private boolean sortAgg_bufIsNull_0;
/* 011 */   private long sortAgg_bufValue_0;
/* 012 */   private boolean sortAgg_bufIsNull_1;
/* 013 */   private double sortAgg_bufValue_1;
/* 014 */   private boolean sortAgg_bufIsNull_2;
/* 015 */   private long sortAgg_bufValue_2;
/* 016 */   private boolean range_initRange_0;
/* 017 */   private long range_nextIndex_0;
/* 018 */   private TaskContext range_taskContext_0;
/* 019 */   private InputMetrics range_inputMetrics_0;
/* 020 */   private long range_batchEnd_0;
/* 021 */   private long range_numElementsTodo_0;
/* 022 */   private boolean sortAgg_sortAgg_isNull_4_0;
/* 023 */   private boolean sortAgg_sortAgg_isNull_9_0;
/* 024 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[] range_mutableStateArray_0 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[3];
/* 025 */
/* 026 */   public GeneratedIteratorForCodegenStage1(Object[] references) {
/* 027 */     this.references = references;
/* 028 */   }
/* 029 */
/* 030 */   public void init(int index, scala.collection.Iterator[] inputs) {
/* 031 */     partitionIndex = index;
/* 032 */     this.inputs = inputs;
/* 033 */
/* 034 */     range_taskContext_0 = TaskContext.get();
/* 035 */     range_inputMetrics_0 = range_taskContext_0.taskMetrics().inputMetrics();
/* 036 */     range_mutableStateArray_0[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
/* 037 */     range_mutableStateArray_0[1] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
/* 038 */     range_mutableStateArray_0[2] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(3, 0);
/* 039 */
/* 040 */   }
/* 041 */
/* 042 */   private void sortAgg_doAggregate_max_0(long sortAgg_expr_0_0) throws java.io.IOException {
/* 043 */     sortAgg_sortAgg_isNull_4_0 = true;
/* 044 */     long sortAgg_value_4 = -1L;
/* 045 */
/* 046 */     if (!sortAgg_bufIsNull_0 && (sortAgg_sortAgg_isNull_4_0 ||
/* 047 */       sortAgg_bufValue_0 > sortAgg_value_4)) {
/* 048 */       sortAgg_sortAgg_isNull_4_0 = false;
/* 049 */       sortAgg_value_4 = sortAgg_bufValue_0;
/* 050 */     }
/* 051 */
/* 052 */     if (!false && (sortAgg_sortAgg_isNull_4_0 ||
/* 053 */       sortAgg_expr_0_0 > sortAgg_value_4)) {
/* 054 */       sortAgg_sortAgg_isNull_4_0 = false;
/* 055 */       sortAgg_value_4 = sortAgg_expr_0_0;
/* 056 */     }
/* 057 */
/* 058 */     sortAgg_bufIsNull_0 = sortAgg_sortAgg_isNull_4_0;
/* 059 */     sortAgg_bufValue_0 = sortAgg_value_4;
/* 060 */   }
/* 061 */
/* 062 */   private void sortAgg_doAggregateWithoutKey_0() throws java.io.IOException {
/* 063 */     // initialize aggregation buffer
/* 064 */     sortAgg_bufIsNull_0 = true;
/* 065 */     sortAgg_bufValue_0 = -1L;
/* 066 */     sortAgg_bufIsNull_1 = false;
/* 067 */     sortAgg_bufValue_1 = 0.0D;
/* 068 */     sortAgg_bufIsNull_2 = false;
/* 069 */     sortAgg_bufValue_2 = 0L;
/* 070 */
/* 071 */     // initialize Range
/* 072 */     if (!range_initRange_0) {
/* 073 */       range_initRange_0 = true;
/* 074 */       initRange(partitionIndex);
/* 075 */     }
/* 076 */
/* 077 */     while (true) {
/* 078 */       if (range_nextIndex_0 == range_batchEnd_0) {
/* 079 */         long range_nextBatchTodo_0;
/* 080 */         if (range_numElementsTodo_0 > 1000L) {
/* 081 */           range_nextBatchTodo_0 = 1000L;
/* 082 */           range_numElementsTodo_0 -= 1000L;
/* 083 */         } else {
/* 084 */           range_nextBatchTodo_0 = range_numElementsTodo_0;
/* 085 */           range_numElementsTodo_0 = 0;
/* 086 */           if (range_nextBatchTodo_0 == 0) break;
/* 087 */         }
/* 088 */         range_batchEnd_0 += range_nextBatchTodo_0 * 1L;
/* 089 */       }
/* 090 */
/* 091 */       int range_localEnd_0 = (int)((range_batchEnd_0 - range_nextIndex_0) / 1L);
/* 092 */       for (int range_localIdx_0 = 0; range_localIdx_0 < range_localEnd_0; range_localIdx_0++) {
/* 093 */         long range_value_0 = ((long)range_localIdx_0 * 1L) + range_nextIndex_0;
/* 094 */
/* 095 */         sortAgg_doConsume_0(range_value_0);
/* 096 */
/* 097 */         // shouldStop check is eliminated
/* 098 */       }
/* 099 */       range_nextIndex_0 = range_batchEnd_0;
/* 100 */       ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(range_localEnd_0);
/* 101 */       range_inputMetrics_0.incRecordsRead(range_localEnd_0);
/* 102 */       range_taskContext_0.killTaskIfInterrupted();
/* 103 */     }
/* 104 */
/* 105 */   }
/* 106 */
/* 107 */   private void initRange(int idx) {
/* 108 */     java.math.BigInteger index = java.math.BigInteger.valueOf(idx);
/* 109 */     java.math.BigInteger numSlice = java.math.BigInteger.valueOf(2L);
/* 110 */     java.math.BigInteger numElement = java.math.BigInteger.valueOf(10L);
/* 111 */     java.math.BigInteger step = java.math.BigInteger.valueOf(1L);
/* 112 */     java.math.BigInteger start = java.math.BigInteger.valueOf(0L);
/* 113 */     long partitionEnd;
/* 114 */
/* 115 */     java.math.BigInteger st = index.multiply(numElement).divide(numSlice).multiply(step).add(start);
/* 116 */     if (st.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) {
/* 117 */       range_nextIndex_0 = Long.MAX_VALUE;
/* 118 */     } else if (st.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) {
/* 119 */       range_nextIndex_0 = Long.MIN_VALUE;
/* 120 */     } else {
/* 121 */       range_nextIndex_0 = st.longValue();
/* 122 */     }
/* 123 */     range_batchEnd_0 = range_nextIndex_0;
/* 124 */
/* 125 */     java.math.BigInteger end = index.add(java.math.BigInteger.ONE).multiply(numElement).divide(numSlice)
/* 126 */     .multiply(step).add(start);
/* 127 */     if (end.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) {
/* 128 */       partitionEnd = Long.MAX_VALUE;
/* 129 */     } else if (end.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) {
/* 130 */       partitionEnd = Long.MIN_VALUE;
/* 131 */     } else {
/* 132 */       partitionEnd = end.longValue();
/* 133 */     }
/* 134 */
/* 135 */     java.math.BigInteger startToEnd = java.math.BigInteger.valueOf(partitionEnd).subtract(
/* 136 */       java.math.BigInteger.valueOf(range_nextIndex_0));
/* 137 */     range_numElementsTodo_0 = startToEnd.divide(step).longValue();
/* 138 */     if (range_numElementsTodo_0 < 0) {
/* 139 */       range_numElementsTodo_0 = 0;
/* 140 */     } else if (startToEnd.remainder(step).compareTo(java.math.BigInteger.valueOf(0L)) != 0) {
/* 141 */       range_numElementsTodo_0++;
/* 142 */     }
/* 143 */   }
/* 144 */
/* 145 */   protected void processNext() throws java.io.IOException {
/* 146 */     while (!sortAgg_initAgg_0) {
/* 147 */       sortAgg_initAgg_0 = true;
/* 148 */       sortAgg_doAggregateWithoutKey_0();
/* 149 */
/* 150 */       // output the result
/* 151 */
/* 152 */       ((org.apache.spark.sql.execution.metric.SQLMetric) references[1] /* numOutputRows */).add(1);
/* 153 */       range_mutableStateArray_0[2].reset();
/* 154 */
/* 155 */       range_mutableStateArray_0[2].zeroOutNullBytes();
/* 156 */
/* 157 */       if (sortAgg_bufIsNull_0) {
/* 158 */         range_mutableStateArray_0[2].setNullAt(0);
/* 159 */       } else {
/* 160 */         range_mutableStateArray_0[2].write(0, sortAgg_bufValue_0);
/* 161 */       }
/* 162 */
/* 163 */       if (sortAgg_bufIsNull_1) {
/* 164 */         range_mutableStateArray_0[2].setNullAt(1);
/* 165 */       } else {
/* 166 */         range_mutableStateArray_0[2].write(1, sortAgg_bufValue_1);
/* 167 */       }
/* 168 */
/* 169 */       if (sortAgg_bufIsNull_2) {
/* 170 */         range_mutableStateArray_0[2].setNullAt(2);
/* 171 */       } else {
/* 172 */         range_mutableStateArray_0[2].write(2, sortAgg_bufValue_2);
/* 173 */       }
/* 174 */       append((range_mutableStateArray_0[2].getRow()));
/* 175 */     }
/* 176 */   }
/* 177 */
/* 178 */   private void sortAgg_doConsume_0(long sortAgg_expr_0_0) throws java.io.IOException {
/* 179 */     // do aggregate
/* 180 */     // common sub-expressions
/* 181 */
/* 182 */     // evaluate aggregate functions and update aggregation buffers
/* 183 */     sortAgg_doAggregate_max_0(sortAgg_expr_0_0);
/* 184 */     sortAgg_doAggregate_avg_0(sortAgg_expr_0_0);
/* 185 */
/* 186 */   }
/* 187 */
/* 188 */   private void sortAgg_doAggregate_avg_0(long sortAgg_expr_0_0) throws java.io.IOException {
/* 189 */     boolean sortAgg_isNull_7 = true;
/* 190 */     double sortAgg_value_7 = -1.0;
/* 191 */
/* 192 */     if (!sortAgg_bufIsNull_1) {
/* 193 */       sortAgg_sortAgg_isNull_9_0 = true;
/* 194 */       double sortAgg_value_9 = -1.0;
/* 195 */       do {
/* 196 */         boolean sortAgg_isNull_10 = false;
/* 197 */         double sortAgg_value_10 = -1.0;
/* 198 */         if (!false) {
/* 199 */           sortAgg_value_10 = (double) sortAgg_expr_0_0;
/* 200 */         }
/* 201 */         if (!sortAgg_isNull_10) {
/* 202 */           sortAgg_sortAgg_isNull_9_0 = false;
/* 203 */           sortAgg_value_9 = sortAgg_value_10;
/* 204 */           continue;
/* 205 */         }
/* 206 */
/* 207 */         if (!false) {
/* 208 */           sortAgg_sortAgg_isNull_9_0 = false;
/* 209 */           sortAgg_value_9 = 0.0D;
/* 210 */           continue;
/* 211 */         }
/* 212 */
/* 213 */       } while (false);
/* 214 */
/* 215 */       sortAgg_isNull_7 = false; // resultCode could change nullability.
/* 216 */
/* 217 */       sortAgg_value_7 = sortAgg_bufValue_1 + sortAgg_value_9;
/* 218 */
/* 219 */     }
/* 220 */     boolean sortAgg_isNull_13 = false;
/* 221 */     long sortAgg_value_13 = -1L;
/* 222 */     if (!false && false) {
/* 223 */       sortAgg_isNull_13 = sortAgg_bufIsNull_2;
/* 224 */       sortAgg_value_13 = sortAgg_bufValue_2;
/* 225 */     } else {
/* 226 */       boolean sortAgg_isNull_17 = true;
/* 227 */       long sortAgg_value_17 = -1L;
/* 228 */
/* 229 */       if (!sortAgg_bufIsNull_2) {
/* 230 */         sortAgg_isNull_17 = false; // resultCode could change nullability.
/* 231 */
/* 232 */         sortAgg_value_17 = sortAgg_bufValue_2 + 1L;
/* 233 */       }
/* 234 */
/* 235 */       sortAgg_isNull_13 = sortAgg_isNull_17;
/* 236 */       sortAgg_value_13 = sortAgg_value_17;
/* 237 */     }
/* 238 */
/* 239 */     sortAgg_bufIsNull_1 = sortAgg_isNull_7;
/* 240 */     sortAgg_bufValue_1 = sortAgg_value_7;
/* 241 */
/* 242 */     sortAgg_bufIsNull_2 = sortAgg_isNull_13;
/* 243 */     sortAgg_bufValue_2 = sortAgg_value_13;
/* 244 */
/* 245 */   }
/* 246 */ }
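The append((range_mutableStateArray_0[2].getRow())) call at line 174 and the processNext() entry point at line 145 only make sense together with the BufferedRowIterator base class. A simplified Scala model of that contract (paraphrased for illustration, not the real Java class):

import scala.collection.mutable

// Simplified model of org.apache.spark.sql.execution.BufferedRowIterator:
// the generated processNext() pushes result rows via append(), and the caller's
// hasNext()/next() drain that buffer, invoking processNext() again when it runs empty.
abstract class BufferedRowIteratorSketch[Row] {
  private val currentRows = mutable.Queue.empty[Row]

  protected def append(row: Row): Unit = currentRows.enqueue(row)  // called by the generated code
  protected def processNext(): Unit                                // body assembled by doCodeGen()

  def hasNext: Boolean = {
    if (currentRows.isEmpty) processNext()
    currentRows.nonEmpty
  }
  def next(): Row = currentRows.dequeue()
}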