wholeStageCodegen (Whole-Stage Code Generation) in SPARK: Starting from Aggregate Code Generation
Posted by 鸿乃江边鸟
Background
This article is based on SPARK 3.3.0.
We explore the logic of SPARK codegen starting from a unit test:
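import org.apache.spark.sql.Row
import org.apache.spark.sql.execution.WholeStageCodegenExec
import org.apache.spark.sql.execution.aggregate.SortAggregateExec
import org.apache.spark.sql.functions.{avg, col, max}

// Runs inside Spark's SQL test harness, which provides `spark` and `withSQLConf`
test("SortAggregate should be included in WholeStageCodegen") {
  val df = spark.range(10).agg(max(col("id")), avg(col("id")))
  withSQLConf("spark.sql.test.forceApplySortAggregate" -> "true") {
    val plan = df.queryExecution.executedPlan
    assert(plan.exists(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
        p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SortAggregateExec]))
    assert(df.collect() === Array(Row(9, 4.5)))
  }
}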
The whole-stage codegen part of the first stage of the physical plan generated for this query is:
WholeStageCodegen
+- *(1) SortAggregate(key=[], functions=[partial_max(id#0L), partial_avg(id#0L)], output=[max#12L, sum#13, count#14L])
   +- *(1) Range (0, 10, step=1, splits=2)
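To make output=[max#12L, sum#13, count#14L] more concrete, here is a plain-Scala sketch (no Spark involved) of what the partial SortAggregate computes for each split, and how a later final aggregation could merge those buffers into Row(9, 4.5). It assumes Range(0, 10, splits=2) is divided evenly into [0, 5) and [5, 10). The names TwoPhaseAgg and Buffer are hypothetical.

```scala
// Plain-Scala model of partial + final aggregation for max(id), avg(id).
// TwoPhaseAgg and Buffer are hypothetical names; no Spark classes are used.
object TwoPhaseAgg {
  // One partial buffer per split, mirroring output=[max, sum, count]
  final case class Buffer(max: Option[Long], sum: Double, count: Long)

  // Partial mode: max starts as null (None), sum as 0.0, count as 0L
  def partial(rows: Seq[Long]): Buffer =
    rows.foldLeft(Buffer(None, 0.0, 0L)) { (b, id) =>
      Buffer(Some(b.max.fold(id)(m => math.max(m, id))), b.sum + id.toDouble, b.count + 1L)
    }

  // Final mode: merge all partial buffers, then evaluate max and sum / count
  def finalResult(buffers: Seq[Buffer]): (Option[Long], Option[Double]) = {
    val max = buffers.flatMap(_.max).maxOption
    val sum = buffers.map(_.sum).sum
    val count = buffers.map(_.count).sum
    (max, if (count == 0L) None else Some(sum / count))
  }

  // Assumed even split of Range(0, 10, splits=2)
  val splits: Seq[Seq[Long]] = Seq(0L until 5L, 5L until 10L)
  val buffers: Seq[Buffer] = splits.map(partial)
}

println(TwoPhaseAgg.buffers)                          // List(Buffer(Some(4),10.0,5), Buffer(Some(9),35.0,5))
println(TwoPhaseAgg.finalResult(TwoPhaseAgg.buffers)) // (Some(9),Some(4.5))
```

The partial step keeps sum and count separately rather than a running average because averages from different splits cannot be merged directly, while sums and counts can.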
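Analysis
First stage: wholeStageCodegen
Code generation in the first stage involves the produce and consume methods of SortAggregateExec and RangeExec. We go through them one by one.
The call flow of the first-stage wholeStageCodegen is as follows: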
WholeStageCodegenExec     SortAggregateExec(partial)         RangeExec
========================================================================
-> execute()
     |
   doExecute() ---------> inputRDDs() ---------------------> inputRDDs()
     |
   doCodeGen()
     |
     +------------------> produce()
                            |
                          doProduce()
                            |
                          doProduceWithoutKeys() ----------> produce()
                                                               |
                                                             doProduce()
                                                               |
                          doConsume() <--------------------- consume()
                            |
                          doConsumeWithoutKeys()
                          (note: the consume() below is called by doProduceWithoutKeys,
                           not by doConsumeWithoutKeys)
   doConsume() <--------- consume()
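The pattern in the diagram can be illustrated with a toy model: the parent's produce() calls the child's produce(), and the child's consume() calls back into the parent's doConsume(). The classes below (ToyCodegen, ToyRange, ToySum) are hypothetical and far simpler than Spark's CodegenSupport. They have no CodegenContext, no ExprCode, and no inputRDDs(), and the root's produce() is called directly instead of through WholeStageCodegenExec. They only show how nested produce/consume calls splice the parent's per-row code into the child's loop.

```scala
// Toy model of the produce/consume protocol. ToyCodegen, ToyRange and ToySum are
// hypothetical; Spark's real CodegenSupport also threads a CodegenContext,
// ExprCode input variables, inputRDDs(), and more.
object ToyCodegenDemo {
  trait ToyCodegen {
    var parent: ToyCodegen = null
    // produce(): called by the parent to ask this operator to emit its code
    def produce(): String = doProduce()
    protected def doProduce(): String
    // consume(): called by this operator for each row; the parent's doConsume
    // decides what Java code runs on the row variable
    def consume(rowVar: String): String = parent.doConsume(rowVar)
    def doConsume(rowVar: String): String =
      throw new UnsupportedOperationException("this operator has no child")
  }

  // Leaf operator: emits the loop and pushes each value upward via consume()
  final class ToyRange(start: Long, end: Long) extends ToyCodegen {
    protected def doProduce(): String =
      s"""for (long range_value_0 = ${start}L; range_value_0 < ${end}L; range_value_0++) {
         |  ${consume("range_value_0")}
         |}""".stripMargin
  }

  // Ungrouped aggregate (sum only): declares its buffer, asks the child to produce
  // the loop, and supplies the per-row update in doConsume()
  final class ToySum(child: ToyCodegen) extends ToyCodegen {
    child.parent = this
    protected def doProduce(): String =
      s"""double agg_bufValue_0 = 0.0;
         |${child.produce()}
         |// hand agg_bufValue_0 to the parent here""".stripMargin
    override def doConsume(rowVar: String): String =
      s"agg_bufValue_0 = agg_bufValue_0 + (double) $rowVar;"
  }

  val generated: String = new ToySum(new ToyRange(0L, 10L)).produce()
}

println(ToyCodegenDemo.generated)
```

The buffer declaration is emitted before the loop and the update inside it only because ToySum.doProduce places child.produce() between them. Spark's doProduceWithoutKeys follows the same shape, with bufVars initialized first and the output consumed after the loop.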
The doConsume method of SortAggregateExec (Partial)
override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
  if (groupingExpressions.isEmpty) {
    doConsumeWithoutKeys(ctx, input)
  } else {
    doConsumeWithKeys(ctx, input)
  }
}
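Note that the ExprCode variable row is passed in but not used in this method. In most cases that variable only exists to hand the input over as an InternalRow. input, on the other hand, is sortAgg_expr_0_0, which is assigned from range_value_0.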
doConsumeWithoutKeys is implemented as follows:
private def doConsumeWithoutKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = {
  // only have DeclarativeAggregate
  val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
  val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++ inputAttributes
  // To individually generate code for each aggregate function, an element in `updateExprs` holds
  // all the expressions for the buffer of an aggregation function.
  val updateExprs = aggregateExpressions.map { e =>
    e.mode match {
      case Partial | Complete =>
        e.aggregateFunction.asInstanceOf[DeclarativeAggregate].updateExpressions
      case PartialMerge | Final =>
        e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions
    }
  }
  ctx.currentVars = bufVars.flatten ++ input
  println(s"updateExprs: $updateExprs") // debug output added for this article
  val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc =>
    bindReferences(updateExprsForOneFunc, inputAttrs)
  }
  val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten)
  val effectiveCodes = ctx.evaluateSubExprEliminationState(subExprs.states.values)
  val bufferEvals = boundUpdateExprs.map { boundUpdateExprsForOneFunc =>
    ctx.withSubExprEliminationExprs(subExprs.states) {
      boundUpdateExprsForOneFunc.map(_.genCode(ctx))
    }
  }

  val aggNames = functions.map(_.prettyName)
  val aggCodeBlocks = bufferEvals.zipWithIndex.map { case (bufferEvalsForOneFunc, i) =>
    val bufVarsForOneFunc = bufVars(i)
    // All the update code for aggregation buffers should be placed in the end
    // of each aggregation function code.
    println(s"bufVarsForOneFunc: $bufVarsForOneFunc") // debug output added for this article
    val updates = bufferEvalsForOneFunc.zip(bufVarsForOneFunc).map { case (ev, bufVar) =>
      s"""
         |${bufVar.isNull} = ${ev.isNull};
         |${bufVar.value} = ${ev.value};
       """.stripMargin
    }
    code"""
          |${ctx.registerComment(s"do aggregate for ${aggNames(i)}")}
          |${ctx.registerComment("evaluate aggregate function")}
          |${evaluateVariables(bufferEvalsForOneFunc)}
          |${ctx.registerComment("update aggregation buffers")}
          |${updates.mkString("\n").trim}
     """.stripMargin
  }

  val codeToEvalAggFuncs = generateEvalCodeForAggFuncs(
    ctx, input, inputAttrs, boundUpdateExprs, aggNames, aggCodeBlocks, subExprs)
  s"""
     |// do aggregate
     |// common sub-expressions
     |$effectiveCodes
     |// evaluate aggregate functions and update aggregation buffers
     |$codeToEvalAggFuncs
   """.stripMargin
}
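What the code generated by doConsumeWithoutKeys does at runtime can be sketched by translating it by hand into Scala. This is an approximation under stated assumptions. The variable names only imitate the generated sortAgg_bufIsNull_i / sortAgg_bufValue_i pairs. The initial values follow MAX (null) and AVG (sum = 0.0, count = 0L). id is treated as non-nullable, as in this plan. PartialAggLoop is a hypothetical name.

```scala
// Hand translation (to Scala) of the per-row work done by the generated partial
// aggregation for max(id) and avg(id). PartialAggLoop is hypothetical; it mirrors
// the generated code's isNull/value pairs but is not Spark output.
object PartialAggLoop {
  // Returns the partial buffer (max, sum, count); None stands for a null buffer value
  def run(ids: Seq[Long]): (Option[Long], Option[Double], Option[Long]) = {
    // Initial values: max = null, avg's sum = 0.0, avg's count = 0L
    var bufIsNull0 = true;  var bufValue0 = -1L // max
    var bufIsNull1 = false; var bufValue1 = 0.0 // sum
    var bufIsNull2 = false; var bufValue2 = 0L  // count

    for (expr00 <- ids) { // expr00 plays the role of sortAgg_expr_0_0 (id is non-nullable)
      // max: Greatest(max, id) skips a null buffer value
      val maxValue = if (bufIsNull0) expr00 else math.max(bufValue0, expr00)
      bufIsNull0 = false; bufValue0 = maxValue

      // sum: Add(sum, coalesce(cast(id as double), 0.0)); the result is null if sum is null
      val isNull7 = bufIsNull1
      val value7 = if (isNull7) -1.0 else bufValue1 + expr00.toDouble
      bufIsNull1 = isNull7; bufValue1 = value7

      // count: If(isnull(id), count, count + 1L); id is never null here
      val isNull13 = bufIsNull2
      val value13 = if (isNull13) -1L else bufValue2 + 1L
      bufIsNull2 = isNull13; bufValue2 = value13
    }

    (if (bufIsNull0) None else Some(bufValue0),
     if (bufIsNull1) None else Some(bufValue1),
     if (bufIsNull2) None else Some(bufValue2))
  }
}

println(PartialAggLoop.run(0L until 10L)) // (Some(9),Some(45.0),Some(10))
```

Each update first computes a fresh isNull/value pair and only then writes it back to the buffer pair. That mirrors the "evaluate aggregate function" and "update aggregation buffers" comments emitted in aggCodeBlocks.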
- val functions = ... and val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++ inputAttributes
  For the AVG aggregate function, the aggregation buffer attributes (aggBufferAttributes) are AttributeReference("sum", sumDataType)() and AttributeReference("count", LongType)(). MAX contributes a single buffer attribute, max. For the current plan, the inputAttributes of SortAggregateExec is AttributeReference("id", LongType, nullable = false)(), so inputAttrs is [max, sum, count, id].
- val updateExprs = aggregateExpressions.map ...
  In the current physical plan the mode is Partial, so updateExpressions is used. These are the partial (per-row) updates: /* sum = */ Add(sum, coalesce(child.cast(sumDataType), Literal.default(sumDataType)), failOnError = useAnsiAdd) and /* count = */ If(child.isNull, count, count + 1L).
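- ctx.currentVars = bufVars.flatten ++ input
  bufVars is assigned in SortAggregateExec's produce path (doProduceWithoutKeys). It holds one ExprCode per buffer attribute, initialized from the functions' initial values, for example those of SUM and COUNT. input is the ExprCode variable named sortAgg_expr_0_0.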
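- val boundUpdateExprs =
  Binds the current input variables into updateExprs. bindReferences replaces each attribute with a reference to its position in inputAttrs, and at genCode time the reference at position i reads ctx.currentVars(i). This works because inputAttrs and currentVars correspond one-to-one, position by position.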
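- val subExprs = and val effectiveCodes =
  Performs common subexpression elimination and evaluates the common subexpressions up front, before the expressions that use them. For the current plan, effectiveCodes is an empty string.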
- val bufferEvals =
  Produces the ExprCode that performs each update: here, the codegen of the Add expression (sum) and of the If expression (count). The printed List(ExprCode(...), ExprCode(...)), laid out for readability:
  ExprCode 1 (Add, updates sum; isNull = sortAgg_isNull_7, value = sortAgg_value_7):
    boolean sortAgg_isNull_7 = true;
    double sortAgg_value_7 = -1.0;
    if (!sortAgg_bufIsNull_1) {
      sortAgg_sortAgg_isNull_9_0 = true;
      double sortAgg_value_9 = -1.0;
      do {
        boolean sortAgg_isNull_10 = false;
        double sortAgg_value_10 = -1.0;
        if (!false) {
          sortAgg_value_10 = (double) sortAgg_expr_0_0;
        }
        if (!sortAgg_isNull_10) {
          sortAgg_sortAgg_isNull_9_0 = false;
          sortAgg_value_9 = sortAgg_value_10;
          continue;
        }
        if (!false) {
          sortAgg_sortAgg_isNull_9_0 = false;
          sortAgg_value_9 = 0.0D;
          continue;
        }
      } while (false);
      sortAgg_isNull_7 = false; // resultCode could change nullability.
      sortAgg_value_7 = sortAgg_bufValue_1 + sortAgg_value_9;
    }
  ExprCode 2 (If, updates count; isNull = sortAgg_isNull_13, value = sortAgg_value_13):
    boolean sortAgg_isNull_13 = false;
    long sortAgg_value_13 = -1L;
    if (!false && false) {
      sortAgg_isNull_13 = sortAgg_bufIsNull_2;
      sortAgg_value_13 = sortAgg_bufValue_2;
    } else {
      boolean sortAgg_isNull_17 = true;
      long sortAgg_value_17 = -1L;
      if (!sortAgg_bufIsNull_2) {
        sortAgg_isNull_17 = false; // resultCode could change nullability.
        sortAgg_value_17 = sortAgg_bufValue_2 + 1L;
      }
      sortAgg_isNull_13 = sortAgg_isNull_17;
      sortAgg_value_13 = sortAgg_value_17;
    }
- val aggNames = functions.map(_.prettyName)
  Defines the names of the aggregate functions. They end up in the names of generated methods such as sortAgg_doAggregate_avg_0.
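- val aggCodeBlocks =
  The code block for each aggregate function. After the aggregation is evaluated, the result is assigned to the buffer variables, which are global (class-level) variables of the generated code. The corresponding generated code is:
    sortAgg_bufIsNull_1 = sortAgg_isNull_7;
    sortAgg_bufValue_1 = sortAgg_value_7;
    sortAgg_bufIsNull_2 = sortAgg_isNull_13;
    sortAgg_bufValue_2 = sortAgg_value_13;
  Here sortAgg_bufValue_1 holds SUM and sortAgg_bufValue_2 holds COUNT.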
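- val codeToEvalAggFuncs = generateEvalCodeForAggFuncs(...)
  Generates the code that evaluates each aggregate function. Here each function is evaluated in its own method, called as: sortAgg_doAggregate_max_0(sortAgg_expr_0_0); sortAgg_doAggregate_avg_0(sortAgg_expr_0_0);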
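- The final s"""...""" string with $effectiveCodes and $codeToEvalAggFuncs
  Assembles the pieces into the final doConsume code: the common subexpressions first, then the calls that evaluate the aggregate functions and update the aggregation buffers.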
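The steps described in the list above can be mimicked with a tiny expression tree: bindReferences turns attributes into positions, genCode resolves those positions through ctx.currentVars, and the results are written back to the buffers. The case classes here (Attr, BoundRef, Add, Coalesce, CastToDouble, IsNull, If, Lit) and the object BindSketch are hypothetical stand-ins for Catalyst expressions. render produces readable pseudo-Java rather than Spark's real generated code. The AVG update expressions are simplified, with no ANSI failOnError handling and no null tracking.

```scala
// Hypothetical mini expression tree illustrating bindReferences + ctx.currentVars.
// None of these classes are Catalyst's; render() emits pseudo-Java for readability.
object BindSketch {
  sealed trait Expr
  final case class Attr(name: String) extends Expr     // column referenced by name
  final case class BoundRef(ordinal: Int) extends Expr // column referenced by position
  final case class Lit(javaLiteral: String) extends Expr
  final case class Add(left: Expr, right: Expr) extends Expr
  final case class Coalesce(children: Seq[Expr]) extends Expr
  final case class CastToDouble(child: Expr) extends Expr
  final case class IsNull(child: Expr) extends Expr
  final case class If(cond: Expr, ifTrue: Expr, ifFalse: Expr) extends Expr

  // Like bindReferences: replace each attribute by its ordinal in inputAttrs
  def bind(e: Expr, inputAttrs: Seq[String]): Expr = e match {
    case Attr(name) =>
      val ordinal = inputAttrs.indexOf(name)
      require(ordinal >= 0, s"Couldn't find $name in ${inputAttrs.mkString("[", ", ", "]")}")
      BoundRef(ordinal)
    case Add(l, r)       => Add(bind(l, inputAttrs), bind(r, inputAttrs))
    case Coalesce(cs)    => Coalesce(cs.map(bind(_, inputAttrs)))
    case CastToDouble(c) => CastToDouble(bind(c, inputAttrs))
    case IsNull(c)       => IsNull(bind(c, inputAttrs))
    case If(c, t, f)     => If(bind(c, inputAttrs), bind(t, inputAttrs), bind(f, inputAttrs))
    case other           => other // BoundRef and Lit are already bound
  }

  // Like genCode on a bound tree: BoundRef(i) reads the variable currentVars(i)
  def render(e: Expr, currentVars: Seq[String]): String = e match {
    case BoundRef(i)     => currentVars(i)
    case Lit(s)          => s
    case Add(l, r)       => s"(${render(l, currentVars)} + ${render(r, currentVars)})"
    case Coalesce(cs)    => cs.map(render(_, currentVars)).mkString("coalesce(", ", ", ")")
    case CastToDouble(c) => s"(double) ${render(c, currentVars)}"
    case IsNull(c)       => s"isNull(${render(c, currentVars)})"
    case If(c, t, f)     => s"(${render(c, currentVars)} ? ${render(t, currentVars)} : ${render(f, currentVars)})"
    case Attr(name)      => throw new IllegalStateException(s"unbound attribute $name")
  }

  // inputAttrs = aggBufferAttributes of max and avg ++ the child's output
  val inputAttrs = Seq("max", "sum", "count", "id")
  // currentVars = bufVars.flatten ++ input, position for position
  val currentVars = Seq("sortAgg_bufValue_0", "sortAgg_bufValue_1", "sortAgg_bufValue_2", "sortAgg_expr_0_0")

  // Simplified AVG update expressions (Partial mode): sum first, then count
  val avgUpdate: Seq[Expr] = Seq(
    Add(Attr("sum"), Coalesce(Seq(CastToDouble(Attr("id")), Lit("0.0D")))),
    If(IsNull(Attr("id")), Attr("count"), Add(Attr("count"), Lit("1L"))))

  val bound: Seq[Expr] = avgUpdate.map(bind(_, inputAttrs))
  val rendered: Seq[String] = bound.map(render(_, currentVars))
  // Write-back into AVG's buffer variables (value part only)
  val updates: Seq[String] =
    Seq("sortAgg_bufValue_1", "sortAgg_bufValue_2").zip(rendered).map { case (buf, v) => s"$buf = $v;" }
}

BindSketch.updates.foreach(println)
```

Binding happens once, when the code is generated. After that the expression no longer knows column names, only positions, which is why currentVars has to be laid out in exactly the same order as inputAttrs.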