SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起
Posted 鸿乃江边鸟
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起相关的知识,希望对你有一定的参考价值。
背景
本文基于 SPARK 3.3.0
从一个unit test来探究SPARK Codegen的逻辑,
test("SortAggregate should be included in WholeStageCodegen")
val df = spark.range(10).agg(max(col("id")), avg(col("id")))
withSQLConf("spark.sql.test.forceApplySortAggregate" -> "true")
val plan = df.queryExecution.executedPlan
assert(plan.exists(p =>
p.isInstanceOf[WholeStageCodegenExec] &&
p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SortAggregateExec]))
assert(df.collect() === Array(Row(9, 4.5)))
该sql形成的执行计划第一部分的全代码生成部分如下:
WholeStageCodegen
+- *(1) SortAggregate(key=[], functions=[partial_max(id#0L), partial_avg(id#0L)], output=[max#12L, sum#13, count#14L])
+- *(1) Range (0, 10, step=1, splits=2)
分析
第一阶段wholeStageCodegen
第一阶段的代码生成涉及到SortAggregateExec和RangeExec的produce和consume方法,这里一一来分析:
第一阶段wholeStageCodegen数据流如下:
WholeStageCodegenExec SortAggregateExec(partial) RangeExec
=========================================================================
-> execute()
|
doExecute() ---------> inputRDDs() -----------------> inputRDDs()
|
doCodeGen()
|
+-----------------> produce()
|
doProduce()
|
doProduceWithoutKeys() -------> produce()
|
doProduce()
|
doConsume()<------------------- consume()
|
doConsumeWithoutKeys()
|并不是doConsumeWithoutKeys调用consume,而是由doProduceWithoutKeys调用
doConsume() <-------- consume()
RangeExec的consume方法
final def consume(ctx: CodegenContext, outputVars: Seq[ExprCode], row: String = null): String =
val inputVarsCandidate =
if (outputVars != null)
assert(outputVars.length == output.length)
// outputVars will be used to generate the code for UnsafeRow, so we should copy them
outputVars.map(_.copy())
else
assert(row != null, "outputVars and row cannot both be null.")
ctx.currentVars = null
ctx.INPUT_ROW = row
output.zipWithIndex.map case (attr, i) =>
BoundReference(i, attr.dataType, attr.nullable).genCode(ctx)
val inputVars = inputVarsCandidate match
case stream: Stream[ExprCode] => stream.force
case other => other
val rowVar = prepareRowVar(ctx, row, outputVars)
// Set up the `currentVars` in the codegen context, as we generate the code of `inputVars`
// before calling `parent.doConsume`. We can't set up `INPUT_ROW`, because parent needs to
// generate code of `rowVar` manually.
ctx.currentVars = inputVars
ctx.INPUT_ROW = null
ctx.freshNamePrefix = parent.variablePrefix
val evaluated = evaluateRequiredVariables(output, inputVars, parent.usedInputs)
// Under certain conditions, we can put the logic to consume the rows of this operator into
// another function. So we can prevent a generated function too long to be optimized by JIT.
// The conditions:
// 1. The config "spark.sql.codegen.splitConsumeFuncByOperator" is enabled.
// 2. `inputVars` are all materialized. That is guaranteed to be true if the parent plan uses
// all variables in output (see `requireAllOutput`).
// 3. The number of output variables must less than maximum number of parameters in Java method
// declaration.
val confEnabled = conf.wholeStageSplitConsumeFuncByOperator
val requireAllOutput = output.forall(parent.usedInputs.contains(_))
val paramLength = CodeGenerator.calculateParamLength(output) + (if (row != null) 1 else 0)
val consumeFunc = if (confEnabled && requireAllOutput
&& CodeGenerator.isValidParamLength(paramLength))
constructDoConsumeFunction(ctx, inputVars, row)
else
parent.doConsume(ctx, inputVars, rowVar)
s"""
|$ctx.registerComment(s"CONSUME: $parent.simpleString(conf.maxToStringFields)")
|$evaluated
|$consumeFunc
""".stripMargin
其中参数outputVars为传入的rangeExc产生的value
-
val inputVarsCandidate =和val inputVars =
对于outputVars 不为空的情况下,直接copy复制一份outputVars值作为输入的变量
如果outputVars为空,而row不为空的情况下,则说明传入的是InteralRow类型的变量,需要调用InteralRow对应的方法获取对应的值 -
val rowVar = prepareRowVar(ctx, row, outputVars)
这部分在RangeExec中不会用到,这里不讲解(因为rangExec这里数据流会走向constructDoConsumeFunction
这里) -
ctx.currentVars = inputVars ctx.INPUT_ROW = null ctx.freshNamePrefix = parent.variablePrefix
这里是为了对evaluateRequiredVariables
方法做铺垫,因为 -
val evaluated = evaluateRequiredVariables(output, inputVars, parent.usedInputs)
其中这里的output
为 Range.getOutputAttrs,即StructType(StructField("id", LongType, nullable = false) :: Nil).toAttributes
inputVars
为range_value_0
parent.usedInputs
为AttributeSet(children.flatMap(_.asInstanceOf[QueryPlan[PlanType]].output))
,和output
一样,也就是Range.getOutputAttrs,即StructType(StructField("id", LongType, nullable = false) :: Nil).toAttributes
因为inputVars
的code为空,所以 evaluated对于该inputVars计算也为空 -
val confEnabled val requireAllOutput
这里的两个条件都是TRUE
-
val paramLength = CodeGenerator.calculateParamLength(output) + (if (row != null) 1 else 0)
计算表达式的长度,对于LONG和DOUBLE类型长度为2,其他的为1,因为range_value_0是LONG类型,所以总的长度为3 -
val consumeFunc =confEnabled && requireAllOutput&& CodeGenerator.isValidParamLength(paramLength)
这里的三个条件都满足,所以数据流向constructDoConsumeFunction方法,如下:private def constructDoConsumeFunction( ctx: CodegenContext, inputVars: Seq[ExprCode], row: String): String = val (args, params, inputVarsInFunc) = constructConsumeParameters(ctx, output, inputVars, row) val rowVar = prepareRowVar(ctx, row, inputVarsInFunc) val doConsume = ctx.freshName("doConsume") ctx.currentVars = inputVarsInFunc ctx.INPUT_ROW = null val doConsumeFuncName = ctx.addNewFunction(doConsume, s""" | private void $doConsume($params.mkString(", ")) throws java.io.IOException | $parent.doConsume(ctx, inputVarsInFunc, rowVar) | """.stripMargin) s""" | $doConsumeFuncName($args.mkString(", ")); """.stripMargin
其中
inputVars
为range_value_0
row
为NULL- val (args, params, inputVarsInFunc) = constructConsumeParameters(ctx, output, inputVars, row)
构造 函数实参,形参,以及形参ExprCode变量,分别为range_value_0
,long sortAgg_expr_0_0
,sortAgg_expr_0_0
- val rowVar = prepareRowVar(ctx, row, inputVarsInFunc)
这里是构造UnsafeRow类型的变量便于传给parent进行消费 ,其中row
为NULL,inputVarsInFunc
为sortAgg_expr_0_0
private def prepareRowVar(ctx: CodegenContext, row: String, colVars: Seq[ExprCode]): ExprCode = if (row != null) ExprCode.forNonNullValue(JavaCode.variable(row, classOf[UnsafeRow])) else if (colVars.nonEmpty) val colExprs = output.zipWithIndex.map case (attr, i) => BoundReference(i, attr.dataType, attr.nullable) val evaluateInputs = evaluateVariables(colVars) // generate the code to create a UnsafeRow ctx.INPUT_ROW = row ctx.currentVars = colVars val ev = GenerateUnsafeProjection.createCode(ctx, colExprs, false) val code = code""" |$evaluateInputs |$ev.code """.stripMargin ExprCode(code, FalseLiteral, ev.value) else // There are no columns ExprCode.forNonNullValue(JavaCode.variable("unsafeRow", classOf[UnsafeRow]))
-
对于val colExprs =
这块是针对当前物理计划的输出(output)与变量值进行绑定,对于RangeExec来说output的值为Range.getOutputAttrs,即StructType(StructField(“id”, LongType, nullable = false) :: Nil).toAttributes ,而当前rangexec的对应的变量为range_value_0 -
val evaluateInputs = evaluateVariables(colVars)
对于不是直接赋值的变量,而是通过计算得到的变量,则需要进行提前计算,在这里不需要计算。 -
val ev = GenerateUnsafeProjection.createCode(ctx, colExprs, false)
这部分是产生UnsafeRow类型的变量,这个UnsafeRow类型的变量里包含了rangExec的产生的变量rang_value_0
里面具体的细节,这里先忽略,以后会有具体的文章分析。 -
ExprCode(code, FalseLiteral, ev.value)
这里就返回ExprCode类型的数据结构,
其中code
如下:range_mutableStateArray_0[0].reset();range_mutableStateArray_0[0].write(0, sortAgg_expr_0_0);
ev.value
如下:range_mutableStateArray_0[0].getRow()
-
val doConsume = ctx.freshName(“doConsume”)
构建函数的名字,这里为sortAgg_doConsume_0
-
val doConsumeFuncName =
构造函数调用,其中主要调用的是parent.doConsume(ctx, inputVarsInFunc, rowVar)
方法,
注意:这里的rowVar
在SortAggregateExec中不会被用到,但是在WholeStageCodeGenExec中会被用到
- val (args, params, inputVarsInFunc) = constructConsumeParameters(ctx, output, inputVars, row)
-
最后的
s"""$ctx.registerComment(s"CONSUME: $parent.simpleString(conf.maxToStringFields)")$evaluated
则是组装代码
以上是关于SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起的主要内容,如果未能解决你的问题,请参考以下文章
SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起
SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起
SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起
SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起