[Original] Experience Sharing (19): How Join Is Implemented in Spark
Posted by barneywill
There are two kinds of join in Spark: the RDD join and the SQL join. Let's look at each in turn:
1 RDD join
org.apache.spark.rdd.PairRDDFunctions
/**
 * Return an RDD containing all pairs of elements with matching keys in `this` and `other`. Each
 * pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and
 * (k, v2) is in `other`. Performs a hash join across the cluster.
 */
def join[W](other: RDD[(K, W)]): RDD[(K, (V, W))] = self.withScope {
  join(other, defaultPartitioner(self, other))
}

/**
 * Return an RDD containing all pairs of elements with matching keys in `this` and `other`. Each
 * pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and
 * (k, v2) is in `other`. Uses the given Partitioner to partition the output RDD.
 */
def join[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (V, W))] = self.withScope {
  this.cogroup(other, partitioner).flatMapValues( pair =>
    for (v <- pair._1.iterator; w <- pair._2.iterator) yield (v, w)
  )
}

/**
 * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the
 * list of values for that key in `this` as well as `other`.
 */
def cogroup[W](other: RDD[(K, W)], partitioner: Partitioner)
    : RDD[(K, (Iterable[V], Iterable[W]))] = self.withScope {
  if (partitioner.isInstanceOf[HashPartitioner] && keyClass.isArray) {
    throw new SparkException("HashPartitioner cannot partition array keys.")
  }
  val cg = new CoGroupedRDD[K](Seq(self, other), partitioner)
  cg.mapValues { case Array(vs, w1s) =>
    (vs.asInstanceOf[Iterable[V]], w1s.asInstanceOf[Iterable[W]])
  }
}
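Before diving in, here is a minimal usage sketch (the pair RDDs `a` and `b` and the SparkContext `sc` are assumed, not from the original post) showing that `join` is simply `cogroup` followed by flattening the two value iterables:

// Minimal sketch, assuming an existing SparkContext `sc`; `a` and `b` are hypothetical pair RDDs.
val a = sc.parallelize(Seq(("k1", 1), ("k2", 2)))
val b = sc.parallelize(Seq(("k1", "x"), ("k1", "y"), ("k3", "z")))

// join delegates to cogroup and then flattens the pair of value iterables:
val viaJoin    = a.join(b)                          // RDD[(String, (Int, String))]
val viaCogroup = a.cogroup(b).flatMapValues { case (vs, ws) =>
  for (v <- vs.iterator; w <- ws.iterator) yield (v, w)
}
// Both produce (k1,(1,x)) and (k1,(1,y)); keys present on only one side are dropped.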
The join operation ends up creating a CoGroupedRDD (via cogroup); the CoGroupedRDD constructor takes a sequence of RDDs, i.e. the multiple RDDs to be joined. Let's look at CoGroupedRDD next:
org.apache.spark.rdd.CoGroupedRDD
class CoGroupedRDD[K: ClassTag](
    @transient var rdds: Seq[RDD[_ <: Product2[K, _]]],
    part: Partitioner)
  extends RDD[(K, Array[Iterable[_]])](rdds.head.context, Nil) {

  override def getDependencies: Seq[Dependency[_]] = {
    rdds.map { rdd: RDD[_] =>
      if (rdd.partitioner == Some(part)) {
        logDebug("Adding one-to-one dependency with " + rdd)
        new OneToOneDependency(rdd)
      } else {
        logDebug("Adding shuffle dependency with " + rdd)
        new ShuffleDependency[K, Any, CoGroupCombiner](
          rdd.asInstanceOf[RDD[_ <: Product2[K, _]]], part, serializer)
      }
    }
  }

  override def compute(s: Partition, context: TaskContext): Iterator[(K, Array[Iterable[_]])] = {
    val split = s.asInstanceOf[CoGroupPartition]
    val numRdds = dependencies.length

    // A list of (rdd iterator, dependency number) pairs
    val rddIterators = new ArrayBuffer[(Iterator[Product2[K, Any]], Int)]
    for ((dep, depNum) <- dependencies.zipWithIndex) dep match {
      case oneToOneDependency: OneToOneDependency[Product2[K, Any]] @unchecked =>
        val dependencyPartition = split.narrowDeps(depNum).get.split
        // Read them from the parent
        val it = oneToOneDependency.rdd.iterator(dependencyPartition, context)
        rddIterators += ((it, depNum))

      case shuffleDependency: ShuffleDependency[_, _, _] =>
        // Read map outputs of shuffle
        val it = SparkEnv.get.shuffleManager
          .getReader(shuffleDependency.shuffleHandle, split.index, split.index + 1, context)
          .read()
        rddIterators += ((it, depNum))
    }

    val map = createExternalMap(numRdds)
    for ((it, depNum) <- rddIterators) {
      map.insertAll(it.map(pair => (pair._1, new CoGroupValue(pair._2, depNum))))
    }

    context.taskMetrics().incMemoryBytesSpilled(map.memoryBytesSpilled)
    context.taskMetrics().incDiskBytesSpilled(map.diskBytesSpilled)
    context.taskMetrics().incPeakExecutionMemory(map.peakMemoryUsedBytes)
    new InterruptibleIterator(context,
      map.iterator.asInstanceOf[Iterator[(K, Array[Iterable[_]])]])
  }

  private def createExternalMap(numRdds: Int)
    : ExternalAppendOnlyMap[K, CoGroupValue, CoGroupCombiner] = {

    val createCombiner: (CoGroupValue => CoGroupCombiner) = value => {
      val newCombiner = Array.fill(numRdds)(new CoGroup)
      newCombiner(value._2) += value._1
      newCombiner
    }
    val mergeValue: (CoGroupCombiner, CoGroupValue) => CoGroupCombiner =
      (combiner, value) => {
        combiner(value._2) += value._1
        combiner
      }
    val mergeCombiners: (CoGroupCombiner, CoGroupCombiner) => CoGroupCombiner =
      (combiner1, combiner2) => {
        var depNum = 0
        while (depNum < numRdds) {
          combiner1(depNum) ++= combiner2(depNum)
          depNum += 1
        }
        combiner1
      }
    new ExternalAppendOnlyMap[K, CoGroupValue, CoGroupCombiner](
      createCombiner, mergeValue, mergeCombiners)
  }
CoGroupedRDD first converts each input RDD into a Dependency, then turns all the dependencies into rddIterators, and finally merges them through an ExternalAppendOnlyMap;
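The combine step can be pictured with a plain-Scala sketch (a simplified in-memory analogue; unlike ExternalAppendOnlyMap it never spills to disk): every value arrives tagged with the index of the RDD it came from and is appended to the matching slot of an Array of buffers:

import scala.collection.mutable

// Conceptual sketch of the cogroup combine step; `depNum` is the index of the source RDD.
def cogroupLocally[K, V](numRdds: Int,
                         records: Iterator[(K, (V, Int))]): Map[K, Array[mutable.Buffer[V]]] = {
  val map = mutable.Map.empty[K, Array[mutable.Buffer[V]]]
  for ((key, (value, depNum)) <- records) {
    // plays the role of createCombiner + mergeValue in the code above
    val combiner = map.getOrElseUpdate(key, Array.fill(numRdds)(mutable.Buffer.empty[V]))
    combiner(depNum) += value
  }
  map.toMap
}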
If an RDD needs to be shuffled, the shuffle goes through the ShuffleManager, whose implementation class is SortShuffleManager; for details of the shuffle process see: https://www.cnblogs.com/barneywill/p/10158457.html
Appendix: the Dependency hierarchy in Spark, i.e. the commonly mentioned wide and narrow dependencies:
org.apache.spark.Dependency
Dependency
  NarrowDependency
    OneToOneDependency
    RangeDependency
  ShuffleDependency
The distinction is the shuffle: if no shuffle is needed it is a NarrowDependency, otherwise it is a ShuffleDependency;
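A quick way to see which dependency a join actually creates, as a hedged sketch (the SparkContext `sc` and the sample data are assumed), is to build the CoGroupedRDD shown above directly and inspect its dependencies:

import org.apache.spark.HashPartitioner
import org.apache.spark.rdd.CoGroupedRDD

val p = new HashPartitioner(4)
val a = sc.parallelize(Seq(1 -> "a", 2 -> "b")).partitionBy(p)
val b = sc.parallelize(Seq(1 -> "x", 3 -> "y")).partitionBy(p)

// Both parents already use the target partitioner, so getDependencies yields OneToOneDependency:
new CoGroupedRDD[Int](Seq(a, b), p).dependencies
  .foreach(d => println(d.getClass.getSimpleName))   // OneToOneDependency, OneToOneDependency

// A parent without that partitioner gets a ShuffleDependency instead:
new CoGroupedRDD[Int](Seq(sc.parallelize(Seq(1 -> "a")), b), p).dependencies
  .foreach(d => println(d.getClass.getSimpleName))   // ShuffleDependency, OneToOneDependency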
2 SQL join
A SQL join goes through a join selection strategy:
org.apache.spark.sql.execution.SparkStrategies.JoinSelection
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {

  // --- BroadcastHashJoin --------------------------------------------------------------------

  case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
    if canBuildRight(joinType) && canBroadcast(right) =>
    Seq(joins.BroadcastHashJoinExec(
      leftKeys, rightKeys, joinType, BuildRight, condition, planLater(left), planLater(right)))

  case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
    if canBuildLeft(joinType) && canBroadcast(left) =>
    Seq(joins.BroadcastHashJoinExec(
      leftKeys, rightKeys, joinType, BuildLeft, condition, planLater(left), planLater(right)))

  // --- ShuffledHashJoin ---------------------------------------------------------------------

  case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
    if !conf.preferSortMergeJoin && canBuildRight(joinType) && canBuildLocalHashMap(right)
      && muchSmaller(right, left) ||
      !RowOrdering.isOrderable(leftKeys) =>
    Seq(joins.ShuffledHashJoinExec(
      leftKeys, rightKeys, joinType, BuildRight, condition, planLater(left), planLater(right)))

  case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
    if !conf.preferSortMergeJoin && canBuildLeft(joinType) && canBuildLocalHashMap(left)
      && muchSmaller(left, right) ||
      !RowOrdering.isOrderable(leftKeys) =>
    Seq(joins.ShuffledHashJoinExec(
      leftKeys, rightKeys, joinType, BuildLeft, condition, planLater(left), planLater(right)))

  // --- SortMergeJoin ------------------------------------------------------------

  case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
    if RowOrdering.isOrderable(leftKeys) =>
    joins.SortMergeJoinExec(
      leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil

  ...
Here conf.preferSortMergeJoin refers to:
org.apache.spark.sql.internal.SQLConf
val PREFER_SORTMERGEJOIN = SQLConfigBuilder("spark.sql.join.preferSortMergeJoin")
  .internal()
  .doc("When true, prefer sort merge join over shuffle hash join.")
  .booleanConf
  .createWithDefault(true)
The configuration is spark.sql.join.preferSortMergeJoin, which defaults to true, i.e. whether to prefer SortMergeJoin;
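A minimal sketch of flipping this flag (the SparkSession name `spark` is assumed):

// Assuming an existing SparkSession named `spark`; the same flag can also be passed at submit time
// via --conf spark.sql.join.preferSortMergeJoin=false.
spark.conf.set("spark.sql.join.preferSortMergeJoin", "false")
println(spark.conf.get("spark.sql.join.preferSortMergeJoin"))   // false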
As you can see, there are three main join implementations, namely BroadcastHashJoinExec, ShuffledHashJoinExec and SortMergeJoinExec, chosen in the following order of priority (a quick way to check which one a given query actually picked is sketched after the list):
- 1 if one side canBroadcast, use BroadcastHashJoinExec;
- 2 if spark.sql.join.preferSortMergeJoin=false (and the build side satisfies canBuildLocalHashMap and muchSmaller), use ShuffledHashJoinExec;
- 3 otherwise use SortMergeJoinExec;
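As a hedged illustration (the DataFrames `big` and `small` below are assumptions, not from the original post), `explain` prints the physical plan and therefore the join implementation that was selected; the broadcast threshold and an explicit `broadcast` hint drive the first rule:

import org.apache.spark.sql.functions.broadcast

// Hypothetical DataFrames; `small` is assumed to fit under spark.sql.autoBroadcastJoinThreshold.
val big   = spark.range(1000000).withColumnRenamed("id", "k")
val small = spark.range(100).withColumnRenamed("id", "k")

big.join(small, "k").explain()             // typically BroadcastHashJoinExec (small side is auto-broadcast)
big.join(broadcast(small), "k").explain()  // the broadcast hint explicitly marks the right side as broadcastable

spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")   // disable automatic broadcasting
big.join(small, "k").explain()             // now falls through to SortMergeJoinExec (preferSortMergeJoin=true)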
Both BroadcastHashJoinExec and ShuffledHashJoinExec build on HashJoin, so let's look at HashJoin first:
2.1 HashJoin
org.apache.spark.sql.execution.joins.HashJoin
protected def join(
    streamedIter: Iterator[InternalRow],
    hashed: HashedRelation,
    numOutputRows: SQLMetric): Iterator[InternalRow] = {

  val joinedIter = joinType match {
    case _: InnerLike =>
      innerJoin(streamedIter, hashed)
    case LeftOuter | RightOuter =>
      outerJoin(streamedIter, hashed)
    case LeftSemi =>
      semiJoin(streamedIter, hashed)
    case LeftAnti =>
      antiJoin(streamedIter, hashed)
    case j: ExistenceJoin =>
      existenceJoin(streamedIter, hashed)
    case x =>
      throw new IllegalArgumentException(
        s"BroadcastHashJoin should not take $x as the JoinType")
  }

  val resultProj = createResultProjection
  joinedIter.map { r =>
    numOutputRows += 1
    resultProj(r)
  }
}

private def innerJoin(
    streamIter: Iterator[InternalRow],
    hashedRelation: HashedRelation): Iterator[InternalRow] = {
  val joinRow = new JoinedRow
  val joinKeys = streamSideKeyGenerator()
  streamIter.flatMap { srow =>
    joinRow.withLeft(srow)
    val matches = hashedRelation.get(joinKeys(srow))
    if (matches != null) {
      matches.map(joinRow.withRight(_)).filter(boundCondition)
    } else {
      Seq.empty
    }
  }
}
Only the inner join (innerJoin) is shown here; the code is straightforward. Note that this is an in-memory operation that runs inside a single partition;
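The same idea in plain Scala, as a conceptual sketch without Spark's InternalRow/HashedRelation machinery: build a hash table on one side, then stream the other side through it:

// Conceptual sketch of a single-partition inner hash join.
def hashInnerJoin[K, L, R](streamed: Iterator[(K, L)],
                           build: Iterator[(K, R)]): Iterator[(K, (L, R))] = {
  // "HashedRelation": group the build side by key, kept entirely in memory.
  val hashed: Map[K, Seq[(K, R)]] = build.toSeq.groupBy(_._1)
  // Stream the other side and probe the hash table, emitting one row per match.
  streamed.flatMap { case (k, l) =>
    hashed.getOrElse(k, Nil).map { case (_, r) => (k, (l, r)) }
  }
}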
2.2 BroadcastHashJoinExec
org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
protected override def doExecute(): RDD[InternalRow] = {
  val numOutputRows = longMetric("numOutputRows")

  val broadcastRelation = buildPlan.executeBroadcast[HashedRelation]()

  streamedPlan.execute().mapPartitions { streamedIter =>
    val hashed = broadcastRelation.value.asReadOnlyCopy()
    TaskContext.get().taskMetrics().incPeakExecutionMemory(hashed.estimatedSize)
    join(streamedIter, hashed, numOutputRows)
  }
}
Here buildPlan is broadcast, and then on streamedPlan the join is performed inside each partition via mapPartitions; the join method is the one in HashJoin;
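The equivalent pattern at the RDD level, as a sketch (all names are assumptions): broadcast the small build side once, then join inside mapPartitions on the streamed side, with no shuffle of the large RDD:

// Sketch of a broadcast (map-side) join on plain RDDs, assuming a SparkContext `sc`.
val large = sc.parallelize(Seq(1 -> "a", 2 -> "b", 1 -> "c"))
val small = Map(1 -> "x", 2 -> "y")                 // build side, small enough to ship to every executor

val smallBc = sc.broadcast(small)
val joined = large.mapPartitions { iter =>
  val lookup = smallBc.value                        // read-only copy of the build side on this executor
  iter.flatMap { case (k, v) => lookup.get(k).map(w => (k, (v, w))) }
}
// joined: RDD[(Int, (String, String))], produced without shuffling `large`.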
2.3 ShuffledHashJoinExec
org.apache.spark.sql.execution.joins.ShuffledHashJoinExec
protected override def doExecute(): RDD[InternalRow] = {
  val numOutputRows = longMetric("numOutputRows")
  streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, buildIter) =>
    val hashed = buildHashedRelation(buildIter)
    join(streamIter, hashed, numOutputRows)
  }
}
Here the two RDDs (produced by streamedPlan and buildPlan) are first combined with zipPartitions, then joined inside each partition; the join method is again the one in HashJoin;
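A plain-RDD sketch of the same shape (names assumed): shuffle both sides with the same partitioner, then zipPartitions and hash-join each pair of co-located partitions:

import org.apache.spark.HashPartitioner

// Sketch of a shuffled hash join on plain RDDs, assuming a SparkContext `sc`.
val p = new HashPartitioner(4)
val streamedRdd = sc.parallelize(Seq(1 -> "a", 2 -> "b")).partitionBy(p)   // larger side
val buildRdd    = sc.parallelize(Seq(1 -> "x", 3 -> "y")).partitionBy(p)   // smaller side

val joined = streamedRdd.zipPartitions(buildRdd) { (streamIter, buildIter) =>
  val hashed = buildIter.toSeq.groupBy(_._1)                               // per-partition "HashedRelation"
  streamIter.flatMap { case (k, v) =>
    hashed.getOrElse(k, Nil).map { case (_, w) => (k, (v, w)) }
  }
}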
2.4 SortMergeJoinExec
org.apache.spark.sql.execution.joins.SortMergeJoinExec
protected override def doExecute(): RDD[InternalRow] = {
  val numOutputRows = longMetric("numOutputRows")
  left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
    val boundCondition: (InternalRow) => Boolean = {
      condition.map { cond =>
        newPredicate(cond, left.output ++ right.output).eval _
      }.getOrElse {
        (r: InternalRow) => true
      }
    }

    // An ordering that can be used to compare keys from both sides.
    val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType))
    val resultProj: InternalRow => InternalRow = UnsafeProjection.create(output, output)

    joinType match {
      case _: InnerLike =>
        new RowIterator {
          private[this] var currentLeftRow: InternalRow = _
          private[this] var currentRightMatches: ArrayBuffer[InternalRow] = _
          private[this] var currentMatchIdx: Int = -1
          private[this] val smjScanner = new SortMergeJoinScanner(
            createLeftKeyGenerator(),
            createRightKeyGenerator(),
            keyOrdering,
            RowIterator.fromScala(leftIter),
            RowIterator.fromScala(rightIter)
          )
          private[this] val joinRow = new JoinedRow

          if (smjScanner.findNextInnerJoinRows()) {
            currentRightMatches = smjScanner.getBufferedMatches
            currentLeftRow = smjScanner.getStreamedRow
            currentMatchIdx = 0
          }

          override def advanceNext(): Boolean = {
            while (currentMatchIdx >= 0) {
              if (currentMatchIdx == currentRightMatches.length) {
                if (smjScanner.findNextInnerJoinRows()) {
                  currentRightMatches = smjScanner.getBufferedMatches
                  currentLeftRow = smjScanner.getStreamedRow
                  currentMatchIdx = 0
                } else {
                  currentRightMatches = null
                  currentLeftRow = null
                  currentMatchIdx = -1
                  return false
                }
              }
              joinRow(currentLeftRow, currentRightMatches(currentMatchIdx))
              currentMatchIdx += 1
              if (boundCondition(joinRow)) {
                numOutputRows += 1
                return true
              }
            }
            false
          }

          override def getRow: InternalRow = resultProj(joinRow)
        }.toScala

      ...
Like ShuffledHashJoinExec, this also calls zipPartitions first and then, within each partition, returns a different RowIterator implementation depending on the joinType; the code above shows the inner-join case, and most of the work is delegated to SortMergeJoinScanner:
org.apache.spark.sql.execution.joins.SortMergeJoinScanner
final def findNextInnerJoinRows(): Boolean = {
  while (advancedStreamed() && streamedRowKey.anyNull) {
    // Advance the streamed side of the join until we find the next row whose join key contains
    // no nulls or we hit the end of the streamed iterator.
  }
  if (streamedRow == null) {
    // We have consumed the entire streamed iterator, so there can be no more matches.
    matchJoinKey = null
    bufferedMatches.clear()
    false
  } else if (matchJoinKey != null && keyOrdering.compare(streamedRowKey, matchJoinKey) == 0) {
    // The new streamed row has the same join key as the previous row, so return the same matches.
    true
  } else if (bufferedRow == null) {
    // The streamed row's join key does not match the current batch of buffered rows and there are
    // no more rows to read from the buffered iterator, so there can be no more matches.
    matchJoinKey = null
    bufferedMatches.clear()
    false
  } else {
    // Advance both the streamed and buffered iterators to find the next pair of matching rows.
    var comp = keyOrdering.compare(streamedRowKey, bufferedRowKey)
    do {
      if (streamedRowKey.anyNull) {
        advancedStreamed()
      } else {
        assert(!bufferedRowKey.anyNull)
        comp = keyOrdering.compare(streamedRowKey, bufferedRowKey)
        if (comp > 0) advancedBufferedToRowWithNullFreeJoinKey()
        else if (comp < 0) advancedStreamed()
      }
    } while (streamedRow != null && bufferedRow != null && comp != 0)
    if (streamedRow == null || bufferedRow == null) {
      // We have either hit the end of one of the iterators, so there can be no more matches.
      matchJoinKey = null
      bufferedMatches.clear()
      false
    } else {
      // The streamed row's join key matches the current buffered row's join, so walk through the
      // buffered iterator to buffer the rest of the matching rows.
      assert(comp == 0)
      bufferMatchingRows()
      true
    }
  }
}

/**
 * Advance the streamed iterator and compute the new row's join key.
 * @return true if the streamed iterator returned a row and false otherwise.
 */
private def advancedStreamed(): Boolean = {
  if (streamedIter.advanceNext()) {
    streamedRow = streamedIter.getRow
    streamedRowKey = streamedKeyGenerator(streamedRow)
    true
  } else {
    streamedRow = null
    streamedRowKey = null
    false
  }
}

/**
 * Advance the buffered iterator until we find a row with join key that does not contain nulls.
 * @return true if the buffered iterator returned a row and false otherwise.
 */
private def advancedBufferedToRowWithNullFreeJoinKey(): Boolean = {
  var foundRow: Boolean = false
  while (!foundRow && bufferedIter.advanceNext()) {
    bufferedRow = bufferedIter.getRow
    bufferedRowKey = bufferedKeyGenerator(bufferedRow)
    foundRow = !bufferedRowKey.anyNull
  }
  if (!foundRow) {
    bufferedRow = null
    bufferedRowKey = null
    false
  } else {
    true
  }
}

/**
 * Called when the streamed and buffered join keys match in order to buffer the matching rows.
 */
private def bufferMatchingRows(): Unit = {
  assert(streamedRowKey != null)
  assert(!streamedRowKey.anyNull)
  assert(bufferedRowKey != null)
  assert(!bufferedRowKey.anyNull)
  assert(keyOrdering.compare(streamedRowKey, bufferedRowKey) == 0)
  // This join key may have been produced by a mutable projection, so we need to make a copy:
  matchJoinKey = streamedRowKey.copy()
  bufferedMatches.clear()
  do {
    bufferedMatches += bufferedRow.copy() // need to copy mutable rows before buffering them
    advancedBufferedToRowWithNullFreeJoinKey()
  } while (bufferedRow != null && keyOrdering.compare(streamedRowKey, bufferedRowKey) == 0)
}
As you can see, the process is much like the merge step of a two-way (binary) merge sort;
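A simplified two-pointer version of that merge over two already key-sorted sequences, as a plain-Scala sketch (it ignores null keys and Spark's buffering/spilling of matches):

// Conceptual sketch of a sort-merge inner join over two sequences already sorted by key.
def sortMergeInnerJoin[L, R](left: IndexedSeq[(Int, L)],
                             right: IndexedSeq[(Int, R)]): Seq[(Int, (L, R))] = {
  val out = scala.collection.mutable.ArrayBuffer.empty[(Int, (L, R))]
  var i = 0
  var j = 0
  while (i < left.length && j < right.length) {
    val cmp = left(i)._1 compare right(j)._1
    if (cmp < 0) i += 1            // advance the streamed side
    else if (cmp > 0) j += 1       // advance the buffered side
    else {
      // Keys match: emit every buffered row with the same key for the current streamed row.
      var k = j
      while (k < right.length && right(k)._1 == left(i)._1) {
        out += ((left(i)._1, (left(i)._2, right(k)._2)))
        k += 1
      }
      i += 1                       // keep j so the next streamed row with the same key re-scans the matches
    }
  }
  out
}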
Appendix: RowIterator is an abstract class, essentially an interface, with a typical iterator definition, as follows:
org.apache.spark.sql.execution.RowIterator
abstract class RowIterator {
  /**
   * Advance this iterator by a single row. Returns `false` if this iterator has no more rows
   * and `true` otherwise. If this returns `true`, then the new row can be retrieved by calling
   * [[getRow]].
   */
  def advanceNext(): Boolean

  /**
   * Retrieve the row from this iterator. This method is idempotent. It is illegal to call this
   * method after [[advanceNext()]] has returned `false`.
   */
  def getRow: InternalRow

  /**
   * Convert this RowIterator into a [[scala.collection.Iterator]].
   */
  def toScala: Iterator[InternalRow] = new RowIteratorToScala(this)
}