pyspark对应的scala代码PythonRDD类

Posted vv.past

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了pyspark对应的scala代码PythonRDD类相关的知识,希望对你有一定的参考价值。

pyspark jvm端的scala代码PythonRDD

代码版本为 spark 2.2.0

1.PythonRDD.class

这个rdd类型是python能接入spark的关键

/**
 * An RDD whose partitions are computed by handing the parent RDD's data to a Python function.
 *
 * It overrides `compute`, `partitioner` and `getPartitions` like any other RDD implementation.
 * According to the original author's notes (not verifiable from this file), this is the object
 * returned by the `_jrdd` property of PySpark's `PipelinedRDD`, and `parent` is the
 * `_prev_jrdd` passed in from there, i.e. the data-source RDD that was built first.
 *
 * @param parent the RDD whose computed partitions are fed to the Python function
 * @param func the user-supplied Python computation
 * @param preservePartitoning when true, this RDD reports the parent's partitioner;
 *                            otherwise it reports none
 */
private[spark] class PythonRDD(
    parent: RDD[_],
    func: PythonFunction,
    preservePartitoning: Boolean)
  extends RDD[Array[Byte]](parent) {

  // Both settings are read once from the configuration, with the defaults shown here.
  val bufferSize = conf.getInt("spark.buffer.size", 65536)
  val reuse_worker = conf.getBoolean("spark.python.worker.reuse", true)

  /** Partitions are exactly those of the first parent. */
  override def getPartitions: Array[Partition] = firstParent.partitions

  /** The parent's partitioner when `preservePartitoning` is set, otherwise `None`. */
  override val partitioner: Option[Partitioner] =
    if (preservePartitoning) firstParent.partitioner else None

  /** This RDD wrapped as a `JavaRDD`. */
  val asJavaRDD: JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)

  /**
   * Computes one partition by delegating to a `PythonRunner`.
   *
   * Per the original notes, this runner is not the `PythonRunner` used at spark-submit time.
   */
  override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
    val pythonRunner = PythonRunner(func, bufferSize, reuse_worker)
    // Asking the parent for its iterator is what produces the input data for this partition.
    val parentRows = firstParent.iterator(split, context)
    pythonRunner.compute(parentRows, split.index, context)
  }
}

2.PythonRunner.class

这个类是rdd内部执行计算时的实体计算类,并不是代码提交时那个启动py4j的PythonRunner

/**
 * Runs Python functions over one partition by exchanging bytes with a Python worker through
 * a socket.
 *
 * `compute` does three things:
 *  1. obtains a worker socket from `SparkEnv.createPythonWorker` (per the original author's
 *     notes the worker is backed by `pyspark.daemon`; that is not visible in this file),
 *  2. starts a `WriterThread` that writes the task setup, the serialized command(s) and the
 *     input iterator to the worker,
 *  3. returns an iterator that reads the worker's length-prefixed output.
 *
 * The data written by the `WriterThread` is whatever `inputIterator` yields, i.e. the output
 * of the upstream (data-source) RDD.
 *
 * @param funcs        chained Python functions; the environment, executable, version and
 *                     accumulator are all taken from the first function of the first chain
 * @param bufferSize   buffer size used for the buffered socket input and output streams
 * @param reuse_worker when true, `SPARK_REUSE_WORKER` is set and the worker socket is released
 *                     back to the environment once the worker signals end of stream
 * @param isUDF        selects which command layout is written to the worker
 * @param argOffsets   argument offsets per chained function; must be as long as `funcs`
 */
private[spark] class PythonRunner(
    funcs: Seq[ChainedPythonFunctions],
    bufferSize: Int,
    reuse_worker: Boolean,
    isUDF: Boolean,
    argOffsets: Array[Array[Int]])
  extends Logging {

  require(funcs.length == argOffsets.length, "argOffsets should have the same length as funcs")

  // Python environment variables, executable and version used to run the functions.
  private val envVars = funcs.head.funcs.head.envVars
  private val pythonExec = funcs.head.funcs.head.pythonExec
  private val pythonVer = funcs.head.funcs.head.pythonVer

  private val accumulator = funcs.head.funcs.head.accumulator

  /**
   * Sends `inputIterator` to a Python worker and returns an iterator over the worker's output.
   *
   * @param inputIterator  elements to write to the worker
   * @param partitionIndex index of the partition, sent to the worker first
   * @param context        task context; used for completion/interruption checks and metrics
   * @return an interruptible iterator of the byte arrays read back from the worker
   */
  def compute(
      inputIterator: Iterator[_],
      partitionIndex: Int,
      context: TaskContext): Iterator[Array[Byte]] = {
    val startTime = System.currentTimeMillis
    val env = SparkEnv.get
    val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",")
    envVars.put("SPARK_LOCAL_DIRS", localdir) // it's also used in monitor thread
    if (reuse_worker) {
      envVars.put("SPARK_REUSE_WORKER", "1")
    }

    // Socket used to talk to the Python worker. The original notes say the environment
    // guarantees a single pyspark.daemon per task; that logic lives outside this file.
    val worker: Socket = env.createPythonWorker(pythonExec, envVars.asScala.toMap)
    @volatile var released = false

    // Thread that writes the input data to the worker socket.
    val writerThread = new WriterThread(env, worker, inputIterator, partitionIndex, context)

    // On task completion, stop the writer thread and close the socket unless it was handed
    // back for reuse. The listener argument is unused (it previously shadowed `context`).
    context.addTaskCompletionListener { _ =>
      writerThread.shutdownOnTaskCompletion()
      if (!reuse_worker || !released) {
        try {
          worker.close()
        } catch {
          case e: Exception =>
            logWarning("Failed to close worker socket", e)
        }
      }
    }

    writerThread.start()
    new MonitorThread(env, worker, context).start()

    val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize))
    // Iterator that pulls results from the worker. It always holds one pre-read element in
    // `_nextObj`; a null there marks the end of the data.
    val stdoutIterator = new Iterator[Array[Byte]] {
      override def next(): Array[Byte] = {
        // Honor the Iterator contract instead of handing out null once exhausted.
        if (!hasNext) {
          throw new NoSuchElementException("next on empty iterator")
        }
        val obj = _nextObj
        _nextObj = read()
        obj
      }

      /**
       * Reads the next record from the worker.
       *
       * Returns null at the end of the data section, or when a read fails after the
       * environment has been stopped. Rethrows the writer thread's exception if it has one.
       */
      private def read(): Array[Byte] = {
        writerThread.exception.foreach(e => throw e)
        try {
          stream.readInt() match {
            case length if length > 0 =>
              val obj = new Array[Byte](length)
              stream.readFully(obj)
              obj
            case 0 => Array.empty[Byte]
            case SpecialLengths.TIMING_DATA =>
              // Timing data from worker
              val bootTime = stream.readLong()
              val initTime = stream.readLong()
              val finishTime = stream.readLong()
              val boot = bootTime - startTime
              val init = initTime - bootTime
              val finish = finishTime - initTime
              val total = finishTime - startTime
              logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot,
                init, finish))
              val memoryBytesSpilled = stream.readLong()
              val diskBytesSpilled = stream.readLong()
              context.taskMetrics.incMemoryBytesSpilled(memoryBytesSpilled)
              context.taskMetrics.incDiskBytesSpilled(diskBytesSpilled)
              // Timing data is not a result record; continue with the next one.
              read()
            case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
              // Signals that an exception has been thrown in python
              val exLength = stream.readInt()
              val obj = new Array[Byte](exLength)
              stream.readFully(obj)
              throw new PythonException(new String(obj, StandardCharsets.UTF_8),
                writerThread.exception.orNull)
            case SpecialLengths.END_OF_DATA_SECTION =>
              // We've finished the data section of the output, but we can still
              // read some accumulator updates:
              val numAccumulatorUpdates = stream.readInt()
              (1 to numAccumulatorUpdates).foreach { _ =>
                val updateLen = stream.readInt()
                val update = new Array[Byte](updateLen)
                stream.readFully(update)
                accumulator.add(update)
              }
              // Check whether the worker is ready to be re-used.
              if (stream.readInt() == SpecialLengths.END_OF_STREAM) {
                if (reuse_worker) {
                  env.releasePythonWorker(pythonExec, envVars.asScala.toMap, worker)
                  released = true
                }
              }
              null
          }
        } catch {

          case e: Exception if context.isInterrupted =>
            logDebug("Exception thrown after task interruption", e)
            throw new TaskKilledException(context.getKillReason().getOrElse("unknown reason"))

          case e: Exception if env.isStopped =>
            logDebug("Exception thrown after context is stopped", e)
            null  // exit silently

          case e: Exception if writerThread.exception.isDefined =>
            logError("Python worker exited unexpectedly (crashed)", e)
            logError("This may have been caused by a prior exception:", writerThread.exception.get)
            throw writerThread.exception.get

          case eof: EOFException =>
            throw new SparkException("Python worker exited unexpectedly (crashed)", eof)
        }
      }

      // Pre-read the first element; must stay below `read()`'s dependencies in this body.
      private var _nextObj: Array[Byte] = read()

      override def hasNext: Boolean = _nextObj != null
    }
    // Hand back the result-pulling iterator, wrapped so it can be interrupted.
    new InterruptibleIterator(context, stdoutIterator)
  }

  /**
   * Daemon thread that writes everything the worker needs to its socket: task setup,
   * Python includes, broadcast changes, the serialized command(s) and finally the input data.
   *
   * @param env            Spark environment (not used inside this class)
   * @param worker         socket connected to the Python worker
   * @param inputIterator  elements to send as the data section
   * @param partitionIndex partition index written at the start of the stream
   * @param context        task context installed on this thread and used for state checks
   */
  class WriterThread(
      env: SparkEnv,
      worker: Socket,
      inputIterator: Iterator[_],
      partitionIndex: Int,
      context: TaskContext)
    extends Thread(s"stdout writer for $pythonExec") {

    @volatile private var _exception: Exception = null

    private val pythonIncludes = funcs.flatMap(_.funcs.flatMap(_.pythonIncludes.asScala)).toSet
    private val broadcastVars = funcs.flatMap(_.funcs.flatMap(_.broadcastVars.asScala))

    setDaemon(true)

    /** Contains the exception thrown while writing the parent iterator to the Python process. */
    def exception: Option[Exception] = Option(_exception)

    /** Terminates the writer thread, ignoring any exceptions that may occur due to cleanup. */
    def shutdownOnTaskCompletion(): Unit = {
      assert(context.isCompleted)
      this.interrupt()
    }

    /**
     * Writes, in order: partition index, Python version, task context info, the SparkFiles
     * root directory, Python includes, broadcast removals/additions, the serialized
     * command(s), the input data, and the end-of-data / end-of-stream markers.
     *
     * Exceptions are never rethrown from here; see the catch block for why.
     */
    override def run(): Unit = Utils.logUncaughtExceptions {
      try {
        TaskContext.setTaskContext(context)
        val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
        val dataOut = new DataOutputStream(stream)
        // Partition index
        dataOut.writeInt(partitionIndex)
        // Python version of driver
        PythonRDD.writeUTF(pythonVer, dataOut)
        // Write out the TaskContextInfo
        dataOut.writeInt(context.stageId())
        dataOut.writeInt(context.partitionId())
        dataOut.writeInt(context.attemptNumber())
        dataOut.writeLong(context.taskAttemptId())
        // sparkFilesDir
        PythonRDD.writeUTF(SparkFiles.getRootDirectory(), dataOut)
        // Python includes (*.zip and *.egg files)
        dataOut.writeInt(pythonIncludes.size)
        for (include <- pythonIncludes) {
          PythonRDD.writeUTF(include, dataOut)
        }
        // Broadcast variables
        val oldBids = PythonRDD.getWorkerBroadcasts(worker)
        val newBids = broadcastVars.map(_.id).toSet
        // number of different broadcasts
        val toRemove = oldBids.diff(newBids)
        val cnt = toRemove.size + newBids.diff(oldBids).size
        dataOut.writeInt(cnt)
        for (bid <- toRemove) {
          // remove the broadcast from worker
          dataOut.writeLong(- bid - 1)  // bid >= 0
          oldBids.remove(bid)
        }
        for (broadcast <- broadcastVars) {
          if (!oldBids.contains(broadcast.id)) {
            // send new broadcast
            dataOut.writeLong(broadcast.id)
            PythonRDD.writeUTF(broadcast.value.path, dataOut)
            oldBids.add(broadcast.id)
          }
        }
        dataOut.flush()
        // Serialized command:
        if (isUDF) {
          // UDF layout: flag 1, number of chains, then per chain its offsets and commands.
          dataOut.writeInt(1)
          dataOut.writeInt(funcs.length)
          funcs.zip(argOffsets).foreach { case (chained, offsets) =>
            dataOut.writeInt(offsets.length)
            offsets.foreach { offset =>
              dataOut.writeInt(offset)
            }
            dataOut.writeInt(chained.funcs.length)
            chained.funcs.foreach { f =>
              dataOut.writeInt(f.command.length)
              dataOut.write(f.command)
            }
          }
        } else {
          // Non-UDF layout: flag 0 followed by the single command of the first function.
          dataOut.writeInt(0)
          val command = funcs.head.funcs.head.command
          dataOut.writeInt(command.length)
          dataOut.write(command)
        }
        // Data values
        PythonRDD.writeIteratorToStream(inputIterator, dataOut)
        dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
        dataOut.writeInt(SpecialLengths.END_OF_STREAM)
        dataOut.flush()
      } catch {
        case e: Exception if context.isCompleted || context.isInterrupted =>
          logDebug("Exception thrown after task completion (likely due to cleanup)", e)
          if (!worker.isClosed) {
            Utils.tryLog(worker.shutdownOutput())
          }

        case e: Exception =>
          // We must avoid throwing exceptions here, because the thread uncaught exception handler
          // will kill the whole executor (see org.apache.spark.executor.Executor).
          _exception = e
          if (!worker.isClosed) {
            Utils.tryLog(worker.shutdownOutput())
          }
      }
    }
  }

  /**
   * Daemon thread that watches whether the task is still running and destroys the Python
   * worker if the task was interrupted before completing.
   *
   * @param env     Spark environment used to destroy the worker
   * @param worker  socket of the worker to destroy
   * @param context task context polled for interruption and completion
   */
  class MonitorThread(env: SparkEnv, worker: Socket, context: TaskContext)
    extends Thread(s"Worker Monitor for $pythonExec") {

    setDaemon(true)

    override def run(): Unit = {
      // Kill the worker if it is interrupted, checking until task completion.
      // TODO: This has a race condition if interruption occurs, as completed may still become true.
      while (!context.isInterrupted && !context.isCompleted) {
        Thread.sleep(2000)
      }
      if (!context.isCompleted) {
        try {
          logWarning("Incomplete task interrupted: Attempting to kill Python Worker")
          env.destroyPythonWorker(pythonExec, envVars.asScala.toMap, worker)
        } catch {
          case e: Exception =>
            logError("Exception when trying to kill worker", e)
        }
      }
    }
  }
}

以上是关于pyspark对应的scala代码PythonRDD类的主要内容,如果未能解决你的问题,请参考以下文章

从 Scala Spark 代码调用 Pyspark 脚本

在pyspark中展平嵌套的json scala代码

将 pyspark 代码移植到 Spark 2.4.3 的 scala 时出现 SparkException

如何在 AWS EMR 中一起添加 2 个(pyspark、scala)步骤?

如何在 Databricks 的 PySpark 中使用在 Scala 中创建的 DataFrame

如何用 Scala 代码详细说明 pyspark 代码?