Spark CTAS with UNION ALL (a large number of UNION ALLs) exceeds spark.driver.maxResultSize (2G)

Posted 鸿乃江边鸟

Background

The SQL runs on the Spark Thrift Server of Spark 3.1.2.

Symptom

Running a Spark SQL statement with many unions failed. The statement contains more than 50 unions, and each union subquery contains a join. Each subquery looks roughly like this:

SELECT  a1.order_no
       ,a1.need_column
       ,a1.join_id
FROM    temp.actul_a a1
join    temp.actul_a a2 on a1.join_id = a2.join_id and a2.need_column = 'we need it'
WHERE   a1.need_column = 'others needs it'

Running the SQL fails with the following error:

Caused by: org.apache.spark.SparkException: Job aborted due to stage failure: Total size of serialized results of 22460 tasks (2.0 GiB) is bigger than spark.driver.maxResultSize (2.0 GiB)
 at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2258)
 at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2207)
 at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2206)
 at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
 at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
 at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
 at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2206)
 at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1079)
 at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1079)
 at scala.Option.foreach(Option.scala:407)
 at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1079)
 at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2445)
 at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2387)
 at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2376)
 at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
 at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:868)
 at org.apache.spark.SparkContext.runJob(SparkContext.scala:2196)
 at org.apache.spark.sql.execution.datasources.FileFormatWriter$.write(FileFormatWriter.scala:200)
 ... 40 more (state=,code=0)
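The two numbers in this message already give a sense of scale. The job failed when the results of 22460 tasks added up to more than 2 GiB. Dividing the limit by that task count gives a lower bound on the average serialized result per task. The snippet below is only that arithmetic. It uses nothing beyond the two numbers in the message, and the object name `ResultSizeMath` is made up for the example:

```scala
// Back-of-envelope check using only the numbers printed in the error message.
object ResultSizeMath {
  val GiB: Long = 1024L * 1024 * 1024

  /** Lower bound on the average serialized result size per task, in bytes. */
  def minAvgBytesPerTask(totalLimitBytes: Long, tasks: Int): Double =
    totalLimitBytes.toDouble / tasks

  def main(args: Array[String]): Unit = {
    // The total passed 2 GiB after 22460 results, so the average was at least this.
    val avg = minAvgBytesPerTask(2 * GiB, 22460)
    println(f"at least $avg%.0f bytes (about ${avg / 1024}%.1f KiB) per task")
  }
}
```

That comes to roughly 93 KiB per task. This is not extreme on its own. It is the number of tasks that multiplies it past the limit.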

Post-mortem and fix

Post-mortem

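  • The tasks of a union are all the tasks of all the RDDs that take part in the union.
  • When a file-writing ResultTask finishes, it sends its serialized result back to the driver. For a write, that result carries metadata about the written files (number of files, file size, number of rows), together with the task's metric updates.
  • Each time a result task in a TaskSet finishes, the driver adds the size of its serialized result to a running total for that TaskSet. It then checks whether the total exceeds spark.driver.maxResultSize, and if it does, the job fails immediately.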

Fix

Lower the initial partition number used by partition coalescing (it was 1000):
set spark.sql.adaptive.coalescePartitions.initialPartitionNum=200;
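The following toy model shows why lowering the initial partition number helps. Every number in it is an assumption for illustration, not measured data:

- 50 union branches;
- each branch contributes exactly `initialPartitionNum` write tasks (that is, no coalescing happens);
- a hypothetical 50 KiB result per task.

`FixEffect` is a made-up name.

```scala
// Hypothetical model: every UNION ALL branch ends in a shuffle whose partitions
// are NOT coalesced, so each branch contributes `partitionsPerBranch` write tasks.
object FixEffect {
  def unionTasks(branches: Int, partitionsPerBranch: Int): Long =
    branches.toLong * partitionsPerBranch

  def totalResultBytes(tasks: Long, avgResultBytes: Long): Long =
    tasks * avgResultBytes

  def main(args: Array[String]): Unit = {
    val limit = 2L * 1024 * 1024 * 1024 // spark.driver.maxResultSize = 2g
    val branches = 50                   // assumption: about 50 union branches
    val avg = 50L * 1024                // assumption: 50 KiB per task result
    for (n <- Seq(1000, 200)) {
      val tasks = unionTasks(branches, n)
      val total = totalResultBytes(tasks, avg)
      println(s"initialPartitionNum=$n -> $tasks tasks, $total bytes, exceeds limit: ${total > limit}")
    }
  }
}
```

Under these assumptions, 1000 partitions per branch gives 50000 tasks and about 2.56 GB of results, which is over the limit. 200 partitions per branch gives 10000 tasks and about 0.51 GB.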

Analysis and explanation

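  • Code-level walkthrough of how the error arises
    Run EXPLAIN on the SQL (the statement is too long to include here, so use your own) to get a physical plan like this: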
== Physical Plan ==
Execute OptimizedCreateHiveTableAsSelectCommand [Database: default, TableName: actul_a, InsertIntoHadoopFsRelationCommand]
+- AdaptiveSparkPlan isFinalPlan=false
   +- Union
      :...

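The Union node has more than 50 children. The two physical operators to focus on are OptimizedCreateHiveTableAsSelectCommand and UnionExec.
Before analyzing them, let's look at the code that throws the error. Searching for the message leads straight to it: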

TaskSetManager.scala

def canFetchMoreResults(size: Long): Boolean = sched.synchronized {
  totalResultSize += size
  calculatedTasks += 1
  if (!isShuffleMapTasks && maxResultSize > 0 && totalResultSize > maxResultSize) {
    val msg = s"Total size of serialized results of $calculatedTasks tasks " +
      s"(${Utils.bytesToString(totalResultSize)}) is bigger than ${config.MAX_RESULT_SIZE.key} " +
      s"(${Utils.bytesToString(maxResultSize)})"
    logError(msg)
    abort(msg)
    false
  } else {
    true
  }
}
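Below is a standalone model of this check that runs without Spark. `ResultSizeGuard` and `GuardDemo` are hypothetical names. The real method calls `abort(msg)` to fail the task set; here a flag takes its place. The model reproduces the properties that matter:

- the total is kept per TaskSetManager;
- shuffle map stages are exempt;
- a limit of 0 disables the check;
- the first result that pushes the total over the limit fails everything.

```scala
// Simplified, standalone model of TaskSetManager.canFetchMoreResults (Spark 3.1).
// One instance corresponds to one TaskSetManager, i.e. one stage attempt.
final class ResultSizeGuard(maxResultSize: Long, isShuffleMapTasks: Boolean) {
  private var totalResultSize = 0L
  private var calculatedTasks = 0
  private var aborted = false

  def isAborted: Boolean = aborted

  /** Called once per successful task with the size of its serialized result. */
  def canFetchMoreResults(size: Long): Boolean = synchronized {
    totalResultSize += size
    calculatedTasks += 1
    // maxResultSize <= 0 means "unlimited"; shuffle map stages are never checked.
    if (!isShuffleMapTasks && maxResultSize > 0 && totalResultSize > maxResultSize) {
      println(s"Total size of serialized results of $calculatedTasks tasks " +
        s"($totalResultSize bytes) is bigger than maxResultSize ($maxResultSize bytes)")
      aborted = true // stands in for abort(msg), which fails the task set
      false
    } else {
      true
    }
  }
}

object GuardDemo {
  def main(args: Array[String]): Unit = {
    val guard = new ResultSizeGuard(maxResultSize = 1000L, isShuffleMapTasks = false)
    // Feed 20 results of 100 bytes each and stop at the first rejected one.
    val accepted = Iterator.fill(20)(100L).takeWhile(guard.canFetchMoreResults).size
    println(s"accepted $accepted results, aborted = ${guard.isAborted}")
    // prints "accepted 10 results, aborted = true"
  }
}
```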

canFetchMoreResults is ultimately called from the statusUpdate method in TaskSchedulerImpl.scala, through TaskResultGetter.enqueueSuccessfulTask:

def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer): Unit = {
  ...
        if (TaskState.isFinished(state)) {
          cleanupTaskState(tid)
          taskSet.removeRunningTask(tid)
          if (state == TaskState.FINISHED) {
            taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData)
          } else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) {
            taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData)
          }
        }
  ...
}

Note that all the tids here belong to the same taskSet (one TaskSetManager). Task sets are taken from the scheduling pool like this:

    val sortedTaskSets = rootPool.getSortedTaskSetQueue

So how do task sets get into the pool? Through the submitTasks method:

override def submitTasks(taskSet: TaskSet): Unit = {
  val tasks = taskSet.tasks
  logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks "
    + "resource profile " + taskSet.resourceProfileId)
  this.synchronized {
    val manager = createTaskSetManager(taskSet, maxTaskFailures)
    val stage = taskSet.stageId
    val stageTaskSets =
      taskSetsByStageIdAndAttempt.getOrElseUpdate(stage, new HashMap[Int, TaskSetManager])
    stageTaskSets(taskSet.stageAttemptId) = manager
    schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)
    ...
  }
  ...
}

This taskSet is submitted by DAGScheduler:

 taskScheduler.submitTasks(new TaskSet(
        tasks.toArray, stage.id, stage.latestInfo.attemptNumber, jobId, properties,
        stage.resourceProfileId))

Its tasks argument is built from partitionsToCompute:

val tasks: Seq[Task[_]] = try {
  val serializedTaskMetrics = closureSerializer.serialize(stage.latestInfo.taskMetrics).array()
  stage match {
    case stage: ShuffleMapStage =>
      stage.pendingPartitions.clear()
      partitionsToCompute.map { id =>
        val locs = taskIdToLocations(id)
        val part = partitions(id)
        stage.pendingPartitions += id
        new ShuffleMapTask(stage.id, stage.latestInfo.attemptNumber,
          taskBinary, part, locs, properties, serializedTaskMetrics, Option(jobId),
          Option(sc.applicationId), sc.applicationAttemptId, stage.rdd.isBarrier())
      }

    case stage: ResultStage =>
      partitionsToCompute.map { id =>
        val p: Int = stage.partitions(id)
        val part = partitions(p)
        val locs = taskIdToLocations(id)
        new ResultTask(stage.id, stage.latestInfo.attemptNumber,
          taskBinary, part, locs, id, properties, serializedTaskMetrics,
          Option(jobId), Option(sc.applicationId), sc.applicationAttemptId,
          stage.rdd.isBarrier())
      }
  }
} catch {
  ...
}
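Stripped of Spark's details, the path from DAGScheduler to TaskSchedulerImpl comes down to three rules:

- one task per partition to compute;
- one TaskSet per stage attempt;
- one TaskSetManager per TaskSet, registered under its stage ID and attempt number.

The result-size counter lives in that TaskSetManager. So every write task of the final stage draws on the same spark.driver.maxResultSize budget. Below is a simplified sketch of that bookkeeping. All types here are hypothetical stand-ins that only mirror the names above:

```scala
import scala.collection.mutable

// Hypothetical, Spark-free model of how one stage becomes one TaskSet and one TaskSetManager.
object StageToTaskSet {
  final case class Task(stageId: Int, partitionId: Int)
  final case class TaskSet(stageId: Int, stageAttemptId: Int, tasks: Seq[Task])
  final class TaskSetManager(val taskSet: TaskSet)

  // stageId -> (stageAttemptId -> manager), like taskSetsByStageIdAndAttempt
  val taskSetsByStageIdAndAttempt =
    mutable.HashMap.empty[Int, mutable.HashMap[Int, TaskSetManager]]

  /** DAGScheduler side: one task per partition that still has to be computed. */
  def createTasks(stageId: Int, partitionsToCompute: Seq[Int]): Seq[Task] =
    partitionsToCompute.map(id => Task(stageId, id))

  /** TaskSchedulerImpl.submitTasks side: one manager per stage attempt. */
  def submitTasks(taskSet: TaskSet): TaskSetManager = {
    val manager = new TaskSetManager(taskSet)
    val stageTaskSets = taskSetsByStageIdAndAttempt.getOrElseUpdate(
      taskSet.stageId, mutable.HashMap.empty[Int, TaskSetManager])
    stageTaskSets(taskSet.stageAttemptId) = manager
    manager
  }
}
```

For example, consider a result stage with 22460 partitions to compute. Its tasks are wrapped in a single TaskSet, and all their results are counted by the single manager registered for that stage attempt.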

partitionsToCompute is ultimately derived from rdd.partitions, which calls getPartitions. For a union, that means UnionRDD.getPartitions.
Now let's look at UnionExec:

case class UnionExec(children: Seq[SparkPlan]) extends SparkPlan {
  ···

  protected override def doExecute(): RDD[InternalRow] =
    sparkContext.union(children.map(_.execute()))
}

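sparkContext.union returns a UnionRDD. Its getPartitions method returns the partitions of all the child RDDs laid end to end, so its partition count is their sum. Combined with the TaskSet analysis above, this means the task count of the UnionRDD is the total task count of all the RDDs in the union.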

So why does the total exceed spark.driver.maxResultSize (2.0 GiB)?
Look again at the OptimizedCreateHiveTableAsSelectCommand node in the plan. It ultimately calls the run method of InsertIntoHadoopFsRelationCommand:

override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = {
  ···
    val updatedPartitionPaths =
      FileFormatWriter.write(
        sparkSession = sparkSession,
        plan = child,
        fileFormat = fileFormat,
        committer = committer,
        outputSpec = FileFormatWriter.OutputSpec(
          committerOutputPath.toString, customPartitionLocations, outputColumns),
        hadoopConf = hadoopConf,
        partitionColumns = partitionColumns,
        bucketSpec = bucketSpec,
        statsTrackers = Seq(basicWriteJobStatsTracker(hadoopConf)),
        options = options)
  ···
}

FileFormatWriter.write ultimately calls SparkContext.runJob. Each task returns a WriteTaskResult, which covers the partitions written plus the size, number of files and number of rows written:

val ret = new Array[WriteTaskResult](rddWithNonEmptyPartitions.partitions.length)
sparkSession.sparkContext.runJob(
  rddWithNonEmptyPartitions,
  (taskContext: TaskContext, iter: Iterator[InternalRow]) => {
    executeTask(
      description = description,
      jobIdInstant = jobIdInstant,
      sparkStageId = taskContext.stageId(),
      sparkPartitionId = taskContext.partitionId(),
      sparkAttemptNumber = taskContext.taskAttemptId().toInt & Integer.MAX_VALUE,
      committer,
      iterator = iter)
  },
  rddWithNonEmptyPartitions.partitions.indices,
  (index, res: WriteTaskResult) => {
    committer.onTaskCommit(res.commitMsg)
    ret(index) = res
  })

executeTask is where the write actually happens:

···
try {
  Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
    // Execute the task to write rows out and commit the task.
    while (iterator.hasNext) {
      dataWriter.write(iterator.next())
    }
    dataWriter.commit()
  })(catchBlock = {
    ···
  })
} catch {
  ···
}

dataWriter.write does the actual writing. dataWriter.commit returns the WriteTaskResult produced while the task ran.
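The following Spark-free sketch shows what a single write task does and what the driver gets back. `CountingWriter`, `WriteStats` and this `runJob` are hypothetical stand-ins. In Spark 3.1 the real `WriteTaskResult` holds a `TaskCommitMessage` and an `ExecutedWriteSummary` with the write statistics. The point it illustrates: every task produces exactly one result, so the number of results the driver receives equals the number of write tasks.

```scala
// Hypothetical, Spark-free sketch of one write task: write every row, then
// commit and hand back a small summary that travels to the driver.
object WriteTaskSketch {
  final case class WriteStats(numFiles: Int, numBytes: Long, numRows: Long)
  final case class WriteTaskResult(commitMsg: String, stats: WriteStats)

  final class CountingWriter(taskId: Int) {
    private var rows = 0L
    private var bytes = 0L
    def write(row: String): Unit = { rows += 1; bytes += row.getBytes("UTF-8").length }
    // Assumption for the sketch: one output file per non-empty task.
    def commit(): WriteTaskResult =
      WriteTaskResult(s"task-$taskId committed", WriteStats(if (rows > 0) 1 else 0, bytes, rows))
  }

  def executeTask(taskId: Int, iterator: Iterator[String]): WriteTaskResult = {
    val dataWriter = new CountingWriter(taskId)
    while (iterator.hasNext) {
      dataWriter.write(iterator.next())
    }
    dataWriter.commit()
  }

  /** Driver side: like the resultHandler above, store each result at its partition index. */
  def runJob(partitions: IndexedSeq[Seq[String]]): Array[WriteTaskResult] = {
    val ret = new Array[WriteTaskResult](partitions.length)
    partitions.indices.foreach { index =>
      ret(index) = executeTask(index, partitions(index).iterator)
    }
    ret
  }
}
```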
In ResultTask, this whole write function is the func in func(context, rdd.iterator(partition, context)):

ResultTask.scala

override def runTask(context: TaskContext): U = {
  // Deserialize the RDD and the func using the broadcast variables.
  val threadMXBean = ManagementFactory.getThreadMXBean
  val deserializeStartTimeNs = System.nanoTime()
  val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
    threadMXBean.getCurrentThreadCpuTime
  } else 0L
  val ser = SparkEnv.get.closureSerializer.newInstance()
  val (rdd, func) = ser.deserialize[(RDD[T], (TaskContext, Iterator[T]) => U)](
    ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
  _executorDeserializeTimeNs = System.nanoTime() - deserializeStartTimeNs
  _executorDeserializeCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
    threadMXBean.getCurrentThreadCpuTime - deserializeStartCpuTime
  } else 0L

  func(context, rdd.iterator(partition, context))
}

runTask is in turn called, via Task.run, from the run method in Executor.scala:

val value = Utils.tryWithSafeFinally {
  val res = task.run(
    taskAttemptId = taskId,
    attemptNumber = taskDescription.attemptNumber,
    metricsSystem = env.metricsSystem,
    resources = taskDescription.resources,
    plugins = plugins)
  threwException = false
  res
} {
  ...
}

After a series of checks, the result is sent through execBackend.statusUpdate to the CoarseGrainedSchedulerBackend on the driver:

val serializedResult: ByteBuffer = {
  if (maxResultSize > 0 && resultSize > maxResultSize) {
    logWarning(s"Finished $taskName. Result is larger than maxResultSize " +
      s"(${Utils.bytesToString(resultSize)} > ${Utils.bytesToString(maxResultSize)}), " +
      s"dropping it.")
    ser.serialize(new IndirectTaskResult[Any](TaskResultBlockId(taskId), resultSize))
  } else if (resultSize > maxDirectResultSize) {
    val blockId = TaskResultBlockId(taskId)
    env.blockManager.putBytes(
      blockId,
      new ChunkedByteBuffer(serializedDirectResult.duplicate()),
      StorageLevel.MEMORY_AND_DISK_SER)
    logInfo(s"Finished $taskName. $resultSize bytes result sent via BlockManager)")
    ser.serialize(new IndirectTaskResult[Any](blockId, resultSize))
  } else {
    logInfo(s"Finished $taskName. $resultSize bytes result sent to driver")
    serializedDirectResult
  }
}

executorSource.SUCCEEDED_TASKS.inc(1L)
setTaskFinishedAndClearInterruptStatus()
plugins.foreach(_.onTaskSucceeded())
execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)

From there it reaches TaskSchedulerImpl.statusUpdate and then canFetchMoreResults, which closes the loop at the code level.
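The executor-side branching can be modelled as a pure function. `ResultRouting` and its case objects are hypothetical names. `maxDirectResultSize` is passed in as a parameter; in Spark it is the smaller of spark.task.maxDirectResultSize (1 MiB by default) and the maximum RPC message size.

Whichever branch is taken, the driver learns the result size:

- a direct result is counted by its serialized length;
- an IndirectTaskResult carries the size explicitly.

Either way, every task ends up in canFetchMoreResults.

```scala
// Standalone model of the executor-side choice shown above. Sizes are in bytes.
object ResultRouting {
  sealed trait Route
  case object DroppedButSizeReported extends Route // > maxResultSize: data dropped, size still sent
  case object ViaBlockManager extends Route        // > maxDirectResultSize: IndirectTaskResult
  case object Direct extends Route                 // shipped inside the status update itself

  def route(resultSize: Long, maxResultSize: Long, maxDirectResultSize: Long): Route =
    if (maxResultSize > 0 && resultSize > maxResultSize) DroppedButSizeReported
    else if (resultSize > maxDirectResultSize) ViaBlockManager
    else Direct
}
```

A result of about 93 KiB, as estimated from the error message, goes directly to the driver. It is far below any single-task threshold, yet each one still adds to the running total.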

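But why are there so many tasks? The cause is spark.sql.adaptive.coalescePartitions.initialPartitionNum.
We had set it to 1000. With AQE and partition coalescing enabled, operations that involve a shuffle therefore produce 1000 post-shuffle partitions. With the local shuffle reader enabled, this number is not reduced much either. See SQLConf.scala: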

def numShufflePartitions: Int = {
  if (adaptiveExecutionEnabled && coalesceShufflePartitionsEnabled) {
    getConf(COALESCE_PARTITIONS_INITIAL_PARTITION_NUM).getOrElse(defaultNumShufflePartitions)
  } else {
    defaultNumShufflePartitions
  }
}
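Here is the same decision as a standalone function, with the configuration values passed in as plain parameters. The object name is made up, and the comments name the configuration key each parameter stands for. This only determines the starting partition count. AQE may still coalesce partitions afterwards, but in this case the count was not reduced much.

```scala
// Standalone model of SQLConf.numShufflePartitions (Spark 3.1) with plain parameters.
object ShufflePartitions {
  def numShufflePartitions(
      adaptiveExecutionEnabled: Boolean,         // spark.sql.adaptive.enabled
      coalesceShufflePartitionsEnabled: Boolean, // spark.sql.adaptive.coalescePartitions.enabled
      initialPartitionNum: Option[Int],          // spark.sql.adaptive.coalescePartitions.initialPartitionNum
      defaultNumShufflePartitions: Int           // spark.sql.shuffle.partitions
  ): Int =
    if (adaptiveExecutionEnabled && coalesceShufflePartitionsEnabled)
      initialPartitionNum.getOrElse(defaultNumShufflePartitions)
    else
      defaultNumShufflePartitions
}
```

With AQE and coalescing on and initialPartitionNum set to 1000, every shuffle starts at 1000 partitions regardless of spark.sql.shuffle.partitions. Leaving initialPartitionNum unset, or setting it to 200 as in the fix, brings the starting point back down.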

Why the partition count is not reduced much with the local shuffle reader enabled is a topic for another post.
