SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起
Posted 鸿乃江边鸟
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起相关的知识,希望对你有一定的参考价值。
背景
本文基于 SPARK 3.3.0
从一个unit test来探究SPARK Codegen的逻辑,
test("SortAggregate should be included in WholeStageCodegen")
val df = spark.range(10).agg(max(col("id")), avg(col("id")))
withSQLConf("spark.sql.test.forceApplySortAggregate" -> "true")
val plan = df.queryExecution.executedPlan
assert(plan.exists(p =>
p.isInstanceOf[WholeStageCodegenExec] &&
p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SortAggregateExec]))
assert(df.collect() === Array(Row(9, 4.5)))
该sql形成的执行计划第二部分的全代码生成部分如下:
WholeStageCodegen
*(2) SortAggregate(key=[], functions=[max(id#0L), avg(id#0L)], output=[max(id)#5L, avg(id)#6])
InputAdapter
+- Exchange SinglePartition, ENSURE_REQUIREMENTS, [id=#13]
分析
第二阶段wholeStageCodegen
第二阶段的代码生成涉及到SortAggregateExec和ShuffleExchangeExec以及InputAdapter的produce和consume方法,这里一一来分析:
第二阶段wholeStageCodegen数据流如下:
WholeStageCodegenExec SortAggregateExec(Final) InputAdapter ShuffleExchangeExec
====================================================================================
-> execute()
|
doExecute() ---------> inputRDDs() -----------------> inputRDDs() -------> execute()
| |
doCodeGen() doExecute()
| |
+-----------------> produce() ShuffledRowRDD
|
doProduce()
|
doProduceWithoutKeys() -------> produce()
|
doProduce()
|
doConsume() <------------------- consume()
|
doConsumeWithoutKeys()
|并不是doConsumeWithoutKeys调用consume,而是由doProduceWithoutKeys调用
doConsume() <-------- consume()
SortAggregateExec(Final) 的doProduce
这里只列出和SortAggregateExec(Partial)的不同的部分:
val (resultVars, genResult) = if (modes.contains(Final) || modes.contains(Complete))
// evaluate aggregate results
ctx.currentVars = flatBufVars
val aggResults = bindReferences(
functions.map(_.evaluateExpression),
aggregateBufferAttributes).map(_.genCode(ctx))
val evaluateAggResults = evaluateVariables(aggResults)
// evaluate result expressions
ctx.currentVars = aggResults
val resultVars = bindReferences(resultExpressions, aggregateAttributes).map(_.genCode(ctx))
(resultVars,
s"""
|$evaluateAggResults
|$evaluateVariables(resultVars)
""".stripMargin)
- 因为我们这里是Final部分,所以我们的数据流和Partial是不同的
- ctx.currentVars = flatBufVars
赋值currentVars
为当前buffer变量,便于下面进行数据绑定,该buffer变量是全局变量 - val aggResults = bindReferences
functions.map(_.evaluateExpression)
这是对最终输出结果的计算,对于SUM
来说是Divide(sum.cast(resultType), count.cast(resultType), failOnError = false)
,生成的代码如下:
boolean sortAgg_isNull_6 = sortAgg_bufIsNull_2; double sortAgg_value_6 = -1.0; if (!sortAgg_bufIsNull_2) sortAgg_value_6 = (double) sortAgg_bufValue_2; boolean sortAgg_isNull_4 = false; double sortAgg_value_4 = -1.0; if (sortAgg_isNull_6 || sortAgg_value_6 == 0) sortAgg_isNull_4 = true; else if (sortAgg_bufIsNull_1) sortAgg_isNull_4 = true; else sortAgg_value_4 = (double)(sortAgg_bufValue_1 / sortAgg_value_6);
aggregateBufferAttributes
聚合函数的buffer属性值sum :: count :: Nil
这样在绑定数据的变量数据的时候和currentVars
是一一对应的
- val evaluateAggResults = evaluateVariables(aggResults)
对聚合的结果进行最终的计算 - ctx.currentVars = aggResults
把最终结果的变量赋值给currentVars
,便于后面的数据绑定 - val resultVars = bindReferences(resultExpressions, aggregateAttributes).map(_.genCode(ctx))
这一步是把聚合结果的变量绑定到聚合表达式中,
其中resultExpressions
为List( avg(id#0L)#3 AS avg(id)#6)
(这里我们只考虑AVG)
aggregateAttributes
是resultExpression
的AttributeReference
的一种表达,便于在BoundReference的时候进行映射绑定
对应的ExprCode为ExprCode(,sortAgg_isNull_4,sortAgg_value_4))
InputAdaptor的 doProduce
InputAdaptor
的主要作用是承上启下,用来适配不支持Codegen的物理计划,sql如下:
override def doProduce(ctx: CodegenContext): String =
// Inline mutable state since an InputRDDCodegen is used once in a task for WholeStageCodegen
val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];",
forceInline = true)
val row = ctx.freshName("row")
val outputVars = if (createUnsafeProjection)
// creating the vars will make the parent consume add an unsafe projection.
ctx.INPUT_ROW = row
ctx.currentVars = null
output.zipWithIndex.map case (a, i) =>
BoundReference(i, a.dataType, a.nullable).genCode(ctx)
else
null
val updateNumOutputRowsMetrics = if (metrics.contains("numOutputRows"))
val numOutputRows = metricTerm(ctx, "numOutputRows")
s"$numOutputRows.add(1);"
else
""
s"""
| while ($limitNotReachedCond $input.hasNext())
| InternalRow $row = (InternalRow) $input.next();
| $updateNumOutputRowsMetrics
| $consume(ctx, outputVars, if (createUnsafeProjection) null else row).trim
| $shouldStopCheckCode
|
""".stripMargin
- val input = ctx.addMutableState(“scala.collection.Iterator”, “input”, v => s"$v = inputs[0];"
定义一个input变量用来接受sortaggregate(partial)的输出的InteralRow(unsafeRow),对应的初始化方法会在init
方法中调用 - val row = ctx.freshName(“row”)
定义一个临时变量用来接受input中的unsafe类型的InteralRow,便于进行迭代操作 - val outputVars = if (createUnsafeProjection)
对于InputAdaptor
来说createUnsafeProjection
是false
, 所以这块返回的是null
- val updateNumOutputRowsMetrics =
因为metrics
不满足条件,所以这里也是返回空字符串 - 代码组装
对输入的每一行数据进行迭代操作, 之后再调用s""" | while ($limitNotReachedCond $input.hasNext()) | InternalRow $row = (InternalRow) $input.next(); | $updateNumOutputRowsMetrics | $consume(ctx, outputVars, if (createUnsafeProjection) null else row).trim | $shouldStopCheckCode | """.stripMargin
consume
方法,
注意: 这里的consume
传入的是row,是InteralRow类型,而不是在RangeExec
中的Long
类型的变量
InputAdaptor的 consume
我们这里只说明和之前不一样的部分,对应的sql如下:
final def consume(ctx: CodegenContext, outputVars: Seq[ExprCode], row: String = null): String =
注意这里的参数 outputVars
为null
row
为InteralRow
类型的变量
- val inputVarsCandidate =
val inputVarsCandidate =
if (outputVars != null)
assert(outputVars.length == output.length)
// outputVars will be used to generate the code for UnsafeRow, so we should copy them
outputVars.map(_.copy())
else
assert(row != null, "outputVars and row cannot both be null.")
ctx.currentVars = null
ctx.INPUT_ROW = row
output.zipWithIndex.map case (attr, i) =>
BoundReference(i, attr.dataType, attr.nullable).genCode(ctx)
这里的数据流向了 else
:
- ctx.INPUT_ROW = row
设置当前的INPUT_ROW
为row
BoundReference
的doGenCode
方法也是走向了另一个分支:
assert(ctx.INPUT_ROW != null, "INPUT_ROW and currentVars cannot both be null.")
val javaType = JavaCode.javaType(dataType)
val value = CodeGenerator.getValue(ctx.INPUT_ROW, dataType, ordinal.toString)
if (nullable)
ev.copy(code =
code"""
|boolean $ev.isNull = $ctx.INPUT_ROW.isNullAt($ordinal);
|$javaType $ev.value = $ev.isNull ?
| $CodeGenerator.defaultValue(dataType) : ($value);
""".stripMargin)
else
ev.copy(code = code"$javaType $ev.value = $value;", isNull = FalseLiteral)
- 分析
-
val value = CodeGenerator.getValue(ctx.INPUT_ROW, dataType,ordinal.toString)
根据数据类型的不同,调用UnsafeRow
的不同方法 -
if (nullable)
因为AttributeReference("sum", sumDataType)()
和AttributeReference("count", LongType)()
表达式nullable
为TRUE
,所以生成的代码为:boolean inputadapter_isNull_0 = inputadapter_row_0.isNullAt(0); long inputadapter_value_0 = inputadapter_isNull_0 ? -1L : (inputadapter_row_0.getLong(0)); boolean inputadapter_isNull_1 = inputadapter_row_0.isNullAt(1); double inputadapter_value_1 = inputadapter_isNull_1 ? -1.0 : (inputadapter_row_0.getDouble(1)); boolean inputadapter_isNull_2 = inputadapter_row_0.isNullAt(2); long inputadapter_value_2 = inputadapter_isNull_2 ? -1L : (inputadapter_row_0.getLong(2));
-
- constructDoConsumeFunction方法中inputVarsInFunc
这里会多一个名为inputadapter_row_0的InternalRow类型的实参
以上是关于SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起的主要内容,如果未能解决你的问题,请参考以下文章
SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起
SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起
SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起
SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起