Whole-Stage Code Generation (wholeStageCodegen) in Spark: Aggregate Code Generation as an Example (3)
Posted by 鸿乃江边鸟
Background
This article is based on Spark 3.3.0.
We explore the logic of Spark codegen starting from a unit test:
test("SortAggregate should be included in WholeStageCodegen")
val df = spark.range(10).agg(max(col("id")), avg(col("id")))
withSQLConf("spark.sql.test.forceApplySortAggregate" -> "true")
val plan = df.queryExecution.executedPlan
assert(plan.exists(p =>
p.isInstanceOf[WholeStageCodegenExec] &&
p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SortAggregateExec]))
assert(df.collect() === Array(Row(9, 4.5)))
The whole-stage-codegen subtree in the first part of the execution plan for this query is:
WholeStageCodegen
+- *(1) SortAggregate(key=[], functions=[partial_max(id#0L), partial_avg(id#0L)], output=[max#12L, sum#13, count#14L])
   +- *(1) Range (0, 10, step=1, splits=2)
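If you want to look at the Java source that whole-stage codegen actually emits for this plan, Spark's debug helpers can print it. A minimal sketch (run e.g. in spark-shell, where spark is predefined; outside Spark's own test suite, which sets the test-only spark.sql.test.forceApplySortAggregate flag, you may see a HashAggregate instead of a SortAggregate):

```scala
import org.apache.spark.sql.execution.debug._   // adds debugCodegen() to Dataset
import org.apache.spark.sql.functions.{avg, col, max}

val df = spark.range(10).agg(max(col("id")), avg(col("id")))
// Prints each WholeStageCodegen subtree followed by its generated Java source,
// which is the code walked through below.
df.debugCodegen()
```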
Analysis
Stage 1: wholeStageCodegen
Code generation for the first stage involves the produce and consume methods of SortAggregateExec and RangeExec; we go through them one by one.
The data flow of the first stage's whole-stage codegen is as follows:
WholeStageCodegenExec SortAggregateExec(partial) RangeExec
=========================================================================
-> execute()
|
doExecute() ---------> inputRDDs() -----------------> inputRDDs()
|
doCodeGen()
|
+-----------------> produce()
|
doProduce()
|
doProduceWithoutKeys() -------> produce()
|
doProduce()
|
doConsume()<------------------- consume()
|
doConsumeWithoutKeys()
|note: consume here is not called by doConsumeWithoutKeys; it is invoked from doProduceWithoutKeys
doConsume() <-------- consume()
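The essence of this flow is the produce/consume handshake defined by the CodegenSupport trait: each operator emits its "driving" loop in doProduce and, for every row it produces, calls consume, which forwards the row's variables to the parent's doConsume. A stripped-down model of that contract (illustrative only, not the real trait) looks like this:

```scala
// Toy model of the CodegenSupport handshake; the method names mirror Spark's,
// but the bodies here are purely illustrative.
trait ToyCodegenSupport {
  protected var parent: ToyCodegenSupport = null

  // Called top-down: remember who will consume our rows, then emit our loop.
  final def produce(parent: ToyCodegenSupport): String = {
    this.parent = parent
    doProduce()
  }

  // Emits the code that drives row production (e.g. RangeExec's while/for loops).
  protected def doProduce(): String

  // Called bottom-up for every produced row: hand the row's variables to the parent.
  final def consume(row: String): String = parent.doConsume(row)

  // Parent-side per-row logic (e.g. SortAggregateExec updating its aggregation buffer).
  protected def doConsume(row: String): String
}
```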
RangeExec's produce
final def produce(ctx: CodegenContext, parent: CodegenSupport): String = executeQuery {
  this.parent = parent
  ctx.freshNamePrefix = variablePrefix
  s"""
     |${ctx.registerComment(s"PRODUCE: ${this.simpleString(conf.maxToStringFields)}")}
     |${doProduce(ctx)}
   """.stripMargin
}
- this.parent = parent and ctx.freshNamePrefix = variablePrefix
  Setting parent gives this node a reference to its parent, so that during consume it can reach the parent's consume method and have the parent's code generated.
  freshNamePrefix is set so that methods generated for different physical operators get distinct names, which prevents duplicate method names and the compilation errors they would cause.
- ctx.registerComment
  This attaches comments to the generated Java code. By default no comments are emitted, because spark.sql.codegen.comments defaults to false.
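To actually see those PRODUCE/CONSUME comments in the generated code, the flag has to be turned on. spark.sql.codegen.comments is read as a static configuration, so a sketch would set it when the session is built rather than with spark.conf.set afterwards:

```scala
import org.apache.spark.sql.SparkSession

// Sketch: enable codegen comments so that ctx.registerComment emits the
// "PRODUCE: ..." / "CONSUME: ..." markers into the generated Java source.
val spark = SparkSession.builder()
  .master("local[*]")
  .config("spark.sql.codegen.comments", "true")
  .getOrCreate()
```

Next comes RangeExec's doProduce, which produce delegates to: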
protected override def doProduce(ctx: CodegenContext): String = {
  val numOutput = metricTerm(ctx, "numOutputRows")

  val initTerm = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initRange")
  val nextIndex = ctx.addMutableState(CodeGenerator.JAVA_LONG, "nextIndex")

  val value = ctx.freshName("value")
  val ev = ExprCode.forNonNullValue(JavaCode.variable(value, LongType))
  val BigInt = classOf[java.math.BigInteger].getName

  // Inline mutable state since not many Range operations in a task
  val taskContext = ctx.addMutableState("TaskContext", "taskContext",
    v => s"$v = TaskContext.get();", forceInline = true)
  val inputMetrics = ctx.addMutableState("InputMetrics", "inputMetrics",
    v => s"$v = $taskContext.taskMetrics().inputMetrics();", forceInline = true)

  // In order to periodically update the metrics without inflicting performance penalty, this
  // operator produces elements in batches. After a batch is complete, the metrics are updated
  // and a new batch is started.
  // In the implementation below, the code in the inner loop is producing all the values
  // within a batch, while the code in the outer loop is setting batch parameters and updating
  // the metrics.

  // Once nextIndex == batchEnd, it's time to progress to the next batch.
  val batchEnd = ctx.addMutableState(CodeGenerator.JAVA_LONG, "batchEnd")

  // How many values should still be generated by this range operator.
  val numElementsTodo = ctx.addMutableState(CodeGenerator.JAVA_LONG, "numElementsTodo")

  // How many values should be generated in the next batch.
  val nextBatchTodo = ctx.freshName("nextBatchTodo")

  // The default size of a batch, which must be positive integer
  val batchSize = 1000

  val initRangeFuncName = ctx.addNewFunction("initRange",
    s"""
      | private void initRange(int idx) {
      |   $BigInt index = $BigInt.valueOf(idx);
      |   $BigInt numSlice = $BigInt.valueOf(${numSlices}L);
      |   $BigInt numElement = $BigInt.valueOf(${numElements.toLong}L);
      |   $BigInt step = $BigInt.valueOf(${step}L);
      |   $BigInt start = $BigInt.valueOf(${start}L);
      |   long partitionEnd;
      |
      |   $BigInt st = index.multiply(numElement).divide(numSlice).multiply(step).add(start);
      |   if (st.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) {
      |     $nextIndex = Long.MAX_VALUE;
      |   } else if (st.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) {
      |     $nextIndex = Long.MIN_VALUE;
      |   } else {
      |     $nextIndex = st.longValue();
      |   }
      |   $batchEnd = $nextIndex;
      |
      |   $BigInt end = index.add($BigInt.ONE).multiply(numElement).divide(numSlice)
      |     .multiply(step).add(start);
      |   if (end.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) {
      |     partitionEnd = Long.MAX_VALUE;
      |   } else if (end.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) {
      |     partitionEnd = Long.MIN_VALUE;
      |   } else {
      |     partitionEnd = end.longValue();
      |   }
      |
      |   $BigInt startToEnd = $BigInt.valueOf(partitionEnd).subtract(
      |     $BigInt.valueOf($nextIndex));
      |   $numElementsTodo  = startToEnd.divide(step).longValue();
      |   if ($numElementsTodo < 0) {
      |     $numElementsTodo = 0;
      |   } else if (startToEnd.remainder(step).compareTo($BigInt.valueOf(0L)) != 0) {
      |     $numElementsTodo++;
      |   }
      | }
     """.stripMargin)

  val localIdx = ctx.freshName("localIdx")
  val localEnd = ctx.freshName("localEnd")
  val stopCheck = if (parent.needStopCheck) {
    s"""
       |if (shouldStop()) {
       |  $nextIndex = $value + ${step}L;
       |  $numOutput.add($localIdx + 1);
       |  $inputMetrics.incRecordsRead($localIdx + 1);
       |  return;
       |}
     """.stripMargin
  } else {
    "// shouldStop check is eliminated"
  }
  val loopCondition = if (limitNotReachedChecks.isEmpty) {
    "true"
  } else {
    limitNotReachedChecks.mkString(" && ")
  }
  s"""
    | // initialize Range
    | if (!$initTerm) {
    |   $initTerm = true;
    |   $initRangeFuncName(partitionIndex);
    | }
    |
    | while ($loopCondition) {
    |   if ($nextIndex == $batchEnd) {
    |     long $nextBatchTodo;
    |     if ($numElementsTodo > ${batchSize}L) {
    |       $nextBatchTodo = ${batchSize}L;
    |       $numElementsTodo -= ${batchSize}L;
    |     } else {
    |       $nextBatchTodo = $numElementsTodo;
    |       $numElementsTodo = 0;
    |       if ($nextBatchTodo == 0) break;
    |     }
    |     $batchEnd += $nextBatchTodo * ${step}L;
    |   }
    |
    |   int $localEnd = (int)(($batchEnd - $nextIndex) / ${step}L);
    |   for (int $localIdx = 0; $localIdx < $localEnd; $localIdx++) {
    |     long $value = ((long)$localIdx * ${step}L) + $nextIndex;
    |     ${consume(ctx, Seq(ev))}
    |     $stopCheck
    |   }
    |   $nextIndex = $batchEnd;
    |   $numOutput.add($localEnd);
    |   $inputMetrics.incRecordsRead($localEnd);
    |   $taskContext.killTaskIfInterrupted();
    | }
   """.stripMargin
}
- val numOutput = metricTerm(ctx, "numOutputRows")
  The numOutput metric records how many rows this operator outputs.
- val initTerm = ... and val nextIndex = ...
  initTerm marks whether this operator's range has already been initialized for the current partition (i.e. whether initRange has run), and nextIndex is the logical index RangeExec walks over to produce its data.
  Both are registered as mutable state, i.e. they become member ("global") variables of the generated class.
- val value = ... and val ev = ...
  ev represents the value produced by RangeExec; it is ultimately handed to the parent via consume(ctx, Seq(ev)).
  The underlying value variable is assigned inside the generated loop by the line "long $value = ((long)$localIdx * ${step}L) + $nextIndex;", which is what allows the parent node to consume it.
- val taskContext = ... and val inputMetrics = ...
  taskContext and inputMetrics are also member variables, and both carry initialization code. Such initialization code is emitted into the generated class's init method, producing lines like:
  range_taskContext_0 = TaskContext.get();
  range_inputMetrics_0 = range_taskContext_0.taskMetrics().inputMetrics();
  The initialization ends up in init because each init snippet is appended to the context's mutableStateInitCode buffer (an ArrayBuffer), whose contents are assembled by ctx.initMutableStates() when WholeStageCodegenExec builds the class. The template of the generated method is:
  public void init(int index, scala.collection.Iterator[] inputs) {
    partitionIndex = index;
    this.inputs = inputs;
    ${ctx.initMutableStates()}
    ${ctx.initPartition()}
  }
- val batchEnd = ... and val numElementsTodo = ...
  These two are also member ("global") variables of the generated class.
- val nextBatchTodo = ...
  A local temporary variable used while looping to produce the data.
- val initRangeFuncName = ...
  This is the logic with which RangeExec computes its per-partition range; it is specific to this operator (every physical plan differs here), so we skip the details.
- The final while ($loopCondition) block
  This is the part that produces different data for each partition, based on the partition index. Note the partitionIndex variable in initRangeFuncName(partitionIndex): it is defined in BufferedRowIterator, the parent class of the generated class, and it is assigned in the init method shown above (partitionIndex = index;). A simplified sketch of how WholeStageCodegenExec drives that init call at runtime follows this list.
- consume(ctx, Seq(ev))
  This is where the parent node consumes the data produced by RangeExec; it is covered in the next part.
- numOutput, inputMetrics and taskContext
  numOutput increments the operator's output-row count.
  inputMetrics increments the records-read counter at the taskMetrics level.
  taskContext.killTaskIfInterrupted() checks whether the current task has been killed and throws immediately if so.
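As referenced above, here is a condensed sketch of how WholeStageCodegenExec drives the compiled class at runtime (metrics, the compilation-failure fallback and the two-input case are omitted); it shows where partitionIndex comes from and when the generated init runs:

```scala
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeAndComment, CodeGenerator}
import org.apache.spark.sql.execution.BufferedRowIterator

// Condensed sketch modeled on WholeStageCodegenExec.doExecute (single-input case only).
def driveGeneratedIterator(
    inputRDD: RDD[InternalRow],     // here: the RDD returned by RangeExec.inputRDDs()
    cleanedSource: CodeAndComment,  // the generated Java source from doCodeGen()
    references: Array[Any]): RDD[InternalRow] = {
  inputRDD.mapPartitionsWithIndex { (index, iter) =>
    val (clazz, _) = CodeGenerator.compile(cleanedSource)
    val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
    // The generated init(): partitionIndex = index; this.inputs = inputs;
    // followed by the code from ctx.initMutableStates() / ctx.initPartition().
    buffer.init(index, Array(iter))
    new Iterator[InternalRow] {
      override def hasNext: Boolean = buffer.hasNext
      override def next(): InternalRow = buffer.next()
    }
  }
}
```

How the parent SortAggregateExec consumes these rows (doConsumeWithoutKeys) is covered in the next part.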