9. Spark Source Code Analysis (yarn-cluster mode) - Task Execution: Reading Shuffle Data Files on the Reduce Side
Posted by Leo Han
This series is based on spark-2.4.6.
In the previous section we looked at how the map side of a ShuffleMapTask writes its shuffle data. In this section we analyze how the reduce side reads that data.
In ShuffleMapTask.runTask there is this step:
writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
where rdd.iterator is:
final def iterator(split: Partition, context: TaskContext): Iterator[T] = {
  if (storageLevel != StorageLevel.NONE) {
    getOrCompute(split, context)
  } else {
    computeOrReadCheckpoint(split, context)
  }
}
Both paths eventually call the following RDD method:
def compute(split: Partition, context: TaskContext): Iterator[T]
RDD has many implementations. Take groupBy as an example: it returns a ShuffledRDD, and ShuffledRDD implements compute as follows:
override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = {
  val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
  SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
    .read()
    .asInstanceOf[Iterator[(K, C)]]
}
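As a quick illustration (a minimal sketch of my own, not code from Spark), the following job produces exactly this kind of ShuffledRDD; collecting it triggers the reduce-side read path analyzed below:

import org.apache.spark.{SparkConf, SparkContext}

object GroupByShuffleDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("groupBy-demo").setMaster("local[2]"))

    // groupBy introduces a shuffle: the resulting RDD is backed by a ShuffledRDD,
    // and it is that ShuffledRDD's compute() that calls shuffleManager.getReader(...).read().
    val grouped = sc.parallelize(1 to 100, numSlices = 4).groupBy(_ % 4)
    println(grouped.getClass.getSimpleName)   // expected to print "ShuffledRDD" on 2.4.x
    grouped.collect().foreach { case (k, vs) => println(s"$k -> ${vs.size} values") }

    sc.stop()
  }
}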
The read here is implemented in BlockStoreShuffleReader:
override def read(): Iterator[Product2[K, C]] = {
  val wrappedStreams = new ShuffleBlockFetcherIterator(
    context,
    blockManager.shuffleClient,
    blockManager,
    mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
    serializerManager.wrapStream,
    SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
    SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue),
    SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS),
    SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM),
    SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true))

  val serializerInstance = dep.serializer.newInstance()

  // Create a key/value iterator for each fetched block stream.
  val recordIter = wrappedStreams.flatMap { case (blockId, wrappedStream) =>
    serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
  }

  // Update the context task metrics for each record read.
  val readMetrics = context.taskMetrics.createTempShuffleReadMetrics()
  val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
    recordIter.map { record =>
      readMetrics.incRecordsRead(1)
      record
    },
    context.taskMetrics().mergeShuffleReadMetrics())

  // An interruptible iterator must be used here in order to support task cancellation.
  val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)

  val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
    if (dep.mapSideCombine) {
      // We are reading values that were already combined on the map side.
      val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
      dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
    } else {
      // No map-side combine: combine the raw values on the reduce side.
      val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
      dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
    }
  } else {
    interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
  }

  // Sort the output if a key ordering is defined.
  val resultIter = dep.keyOrdering match {
    case Some(keyOrd: Ordering[K]) =>
      val sorter = new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = dep.serializer)
      sorter.insertAll(aggregatedIter)
      context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
      context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
      context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes)
      // Use completion callback to stop sorter if task was finished/cancelled.
      context.addTaskCompletionListener[Unit](_ => {
        sorter.stop()
      })
      CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
    case None =>
      aggregatedIter
  }

  resultIter match {
    case _: InterruptibleIterator[Product2[K, C]] => resultIter
    case _ =>
      // Wrap again to support task cancellation, since the aggregator or sorter
      // may have consumed the previous interruptible iterator.
      new InterruptibleIterator[Product2[K, C]](context, resultIter)
  }
}
The first thing to note here is mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition): this goes to the master (the driver-side MapOutputTracker) to find out which shuffle blocks the current node needs to fetch.
The important logic lives in ShuffleBlockFetcherIterator. Also note a couple of configuration parameters:
spark.reducer.maxSizeInFlight
spark.reducer.maxReqsInFlight
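For reference, these knobs could be set as in the sketch below (values are the 2.4.x defaults as I understand them, shown purely for illustration, not as tuning advice):

import org.apache.spark.SparkConf

// Illustrative values only:
//   spark.reducer.maxSizeInFlight            - caps the total bytes of outstanding fetch requests
//   spark.reducer.maxReqsInFlight            - caps the number of outstanding fetch requests
//   spark.reducer.maxBlocksInFlightPerAddress - caps blocks requested from one executor at a time
val conf = new SparkConf()
  .set("spark.reducer.maxSizeInFlight", "48m")
  .set("spark.reducer.maxReqsInFlight", Int.MaxValue.toString)
  .set("spark.reducer.maxBlocksInFlightPerAddress", Int.MaxValue.toString)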
Right after a ShuffleBlockFetcherIterator is constructed, it runs its initialize method:
private[this] def initialize(): Unit = {
  context.addTaskCompletionListener[Unit](_ => cleanup())
  val remoteRequests = splitLocalRemoteBlocks()
  fetchRequests ++= Utils.randomize(remoteRequests)
  fetchUpToMaxBytes()
  val numFetches = remoteRequests.size - fetchRequests.size
  fetchLocalBlocks()
}
First, splitLocalRemoteBlocks works out which blocks have to be fetched, and from where:
private[this] def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
  // Keep remote requests at most maxBytesInFlight / 5 in size so that we can fetch from
  // up to 5 nodes in parallel instead of blocking on the output of a single node.
  val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)
  val remoteRequests = new ArrayBuffer[FetchRequest]
  for ((address, blockInfos) <- blocksByAddress) {
    if (address.executorId == blockManager.blockManagerId.executorId) {
      // Blocks on this executor are read locally; only sanity-check their sizes.
      blockInfos.find(_._2 <= 0) match {
        case Some((blockId, size)) if size < 0 =>
          throw new BlockException(blockId, "Negative block size " + size)
        case Some((blockId, size)) if size == 0 =>
          throw new BlockException(blockId, "Zero-sized blocks should be excluded.")
        case None => // do nothing.
      }
      localBlocks ++= blockInfos.map(_._1)
      numBlocksToFetch += localBlocks.size
    } else {
      val iterator = blockInfos.iterator
      var curRequestSize = 0L
      var curBlocks = new ArrayBuffer[(BlockId, Long)]
      while (iterator.hasNext) {
        val (blockId, size) = iterator.next()
        if (size < 0) {
          throw new BlockException(blockId, "Negative block size " + size)
        } else if (size == 0) {
          throw new BlockException(blockId, "Zero-sized blocks should be excluded.")
        } else {
          curBlocks += ((blockId, size))
          remoteBlocks += blockId
          numBlocksToFetch += 1
          curRequestSize += size
        }
        if (curRequestSize >= targetRequestSize ||
            curBlocks.size >= maxBlocksInFlightPerAddress) {
          // Close out the current FetchRequest and start accumulating a new batch.
          remoteRequests += new FetchRequest(address, curBlocks)
          curBlocks = new ArrayBuffer[(BlockId, Long)]
          curRequestSize = 0
        }
      }
      // Add the final request for whatever blocks are left over.
      if (curBlocks.nonEmpty) {
        remoteRequests += new FetchRequest(address, curBlocks)
      }
    }
  }
  remoteRequests
}
As you can see, this distinguishes between local and remote data (data is represented here as Blocks). For local data, the corresponding BlockIds are simply added to the localBlocks set. Remote data is traversed address by address, i.e. all blocks on one node are handled together. While iterating over a node's blocks, once the accumulated block size reaches targetRequestSize, or the number of accumulated blocks reaches maxBlocksInFlightPerAddress, the blocks collected so far are packed into a single FetchRequest. targetRequestSize is the spark.reducer.maxSizeInFlight mentioned above divided by 5; dividing by 5 increases parallelism by allowing fetches from several nodes at once. maxBlocksInFlightPerAddress caps how many blocks a single request to one node may carry, and by default it is Int.MaxValue. At this point the local and remote blocks have been split, and fetchUpToMaxBytes then starts issuing fetch requests for the remote blocks, up to the in-flight limits (a simplified sketch of the request-packing logic follows below).
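The following standalone sketch (my own illustration, not Spark's FetchRequest class) shows the packing logic just described: blocks from one remote address are accumulated until either their total size reaches targetRequestSize (= maxBytesInFlight / 5) or their count reaches maxBlocksInFlightPerAddress, at which point one request is emitted.

import scala.collection.mutable.ArrayBuffer

object FetchRequestPacking {
  final case class SimpleFetchRequest(blocks: Seq[(String, Long)]) {
    val size: Long = blocks.map(_._2).sum
  }

  def packIntoRequests(
      blocks: Seq[(String, Long)],              // (blockId, size) pairs for one address
      maxBytesInFlight: Long,
      maxBlocksInFlightPerAddress: Int): Seq[SimpleFetchRequest] = {
    val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)
    val requests = new ArrayBuffer[SimpleFetchRequest]
    var curBlocks = new ArrayBuffer[(String, Long)]
    var curSize = 0L
    for ((id, size) <- blocks) {
      curBlocks += ((id, size))
      curSize += size
      if (curSize >= targetRequestSize || curBlocks.size >= maxBlocksInFlightPerAddress) {
        requests += SimpleFetchRequest(curBlocks.toList)   // close out the current request
        curBlocks = new ArrayBuffer[(String, Long)]
        curSize = 0L
      }
    }
    if (curBlocks.nonEmpty) requests += SimpleFetchRequest(curBlocks.toList)
    requests.toList
  }

  def main(args: Array[String]): Unit = {
    // With maxBytesInFlight = 48 MB, each request targets roughly 9.6 MB of blocks.
    val blocks = (1 to 20).map(i => (s"shuffle_0_${i}_2", 2L * 1024 * 1024))
    packIntoRequests(blocks, 48L * 1024 * 1024, Int.MaxValue).zipWithIndex.foreach {
      case (req, i) => println(s"request $i: ${req.blocks.size} blocks, ${req.size} bytes")
    }
  }
}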
Each fetch request is sent via sendRequest, where one piece of logic deserves attention:
if (req.size > maxReqSizeShuffleToMem) {
  shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
    blockFetchingListener, this)
} else {
  shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
    blockFetchingListener, null)
}
This checks the size of the request: if the data to be fetched is larger than maxReqSizeShuffleToMem, the fetched blocks are written to local disk rather than kept in memory. maxReqSizeShuffleToMem is configured via spark.maxRemoteBlockSizeFetchToMem and defaults to Int.MaxValue - 512 bytes.
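A common way to exercise the disk path is to lower this threshold so that very large remote blocks are streamed to disk instead of being buffered in memory. A minimal sketch with an illustrative value (not a recommendation):

import org.apache.spark.SparkConf

// With this (illustrative) setting, any fetch request larger than 200 MB is
// streamed to a temporary file on disk instead of a byte buffer in memory.
val conf = new SparkConf()
  .set("spark.maxRemoteBlockSizeFetchToMem", (200L * 1024 * 1024).toString)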
This ultimately calls NettyBlockTransferService.fetchBlocks:
override def fetchBlocks(
    host: String,
    port: Int,
    execId: String,
    blockIds: Array[String],
    listener: BlockFetchingListener,
    tempFileManager: DownloadFileManager): Unit = {
  try {
    val blockFetchStarter = new RetryingBlockFetcher.BlockFetchStarter {
      override def createAndStart(blockIds: Array[String], listener: BlockFetchingListener) {
        val client = clientFactory.createClient(host, port)
        new OneForOneBlockFetcher(client, appId, execId, blockIds, listener,
          transportConf, tempFileManager).start()
      }
    }

    val maxRetries = transportConf.maxIORetries()
    if (maxRetries > 0) {
      new RetryingBlockFetcher(transportConf, blockFetchStarter, blockIds, listener).start()
    } else {
      blockFetchStarter.createAndStart(blockIds, listener)
    }
  } catch {
    case e: Exception =>
      logError("Exception while beginning fetchBlocks", e)
      blockIds.foreach(listener.onBlockFetchFailure(_, e))
  }
}
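Note that maxRetries comes from transportConf.maxIORetries(); for the shuffle module this is, as far as I can tell, backed by spark.shuffle.io.maxRetries (with spark.shuffle.io.retryWait controlling the wait between attempts), so failed fetches are retried by RetryingBlockFetcher before the failure is surfaced. Illustrative settings:

import org.apache.spark.SparkConf

// Illustrative values: maxRetries > 0 wraps the fetch in a RetryingBlockFetcher;
// setting it to 0 makes fetchBlocks call blockFetchStarter.createAndStart directly.
val conf = new SparkConf()
  .set("spark.shuffle.io.maxRetries", "3")
  .set("spark.shuffle.io.retryWait", "5s")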
Either way, what finally gets started is a OneForOneBlockFetcher:
public void start() {
  if (blockIds.length == 0) {
    throw new IllegalArgumentException("Zero-sized blockIds array");
  }

  client.sendRpc(openMessage.toByteBuffer(), new RpcResponseCallback() {
    @Override
    public void onSuccess(ByteBuffer response) {
      try {
        streamHandle = (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response);
        for (int i = 0; i < streamHandle.numChunks; i++) {
          if (downloadFileManager != null) {
            client.stream(OneForOneStreamManager.genStreamChunkId(streamHandle.streamId, i),
              new DownloadCallback(i));
          } else {
            client.fetchChunk(streamHandle.streamId, i, chunkCallback);
          }
        }
      } catch (Exception e) {
        failRemainingBlocks(blockIds, e);
      }
    }

    @Override
    public void onFailure(Throwable e) {
      failRemainingBlocks(blockIds, e);
    }
  });
}
The fetcher first sends an OpenBlocks message to the node that holds the data. On success, it uses the TransportClient to fetch the corresponding data, and the branch taken depends on whether downloadFileManager is null. This is exactly the condition discussed above: if the data to be fetched is larger than maxReqSizeShuffleToMem, the blocks have to be written to files and downloadFileManager is non-null; otherwise the data is kept in memory.
- The write-to-file path ultimately sends a StreamRequest.
- The in-memory path sends a ChunkFetchRequest.
When the remote node responds successfully, the corresponding callback handles the result:
public void stream(String streamId, StreamCallback callback) {
  StdChannelListener listener = new StdChannelListener(streamId) {
    @Override
    void handleFailure(String errorMsg, Throwable cause) throws Exception {
      callback.onFailure(streamId, new IOException(errorMsg, cause));
    }
  };
  synchronized (this) {
    handler.addStreamCallback(streamId, callback);
    channel.writeAndFlush(new StreamRequest(streamId)).addListener(listener);
  }
}

public void fetchChunk(
    long streamId,
    int chunkIndex,
    ChunkReceivedCallback callback) {
  StreamChunkId streamChunkId = new StreamChunkId(streamId, chunkIndex);
  StdChannelListener listener = new StdChannelListener(streamChunkId) {
    @Override
    void handleFailure(String errorMsg, Throwable cause) {
      handler.removeFetchRequest(streamChunkId);
      callback.onFailure(chunkIndex, new IOException(errorMsg, cause));
    }
  };
  handler.addFetchRequest(streamChunkId, callback);
  channel.writeAndFlush(new ChunkFetchRequest(streamChunkId)).addListener(listener);
}
At this point the reduce side has sent its requests. Next, let's look at how the serving node responds to them.
The OpenBlocks request is ultimately handled in NettyBlockRpcServer:
override def receive(
    client: TransportClient,
    rpcMessage: ByteBuffer,
    responseContext: RpcResponseCallback): Unit = {
  val message = BlockTransferMessage.Decoder.fromByteBuffer(rpcMessage)
  logTrace(s"Received request: $message")

  message match {
    case openBlocks: OpenBlocks =>
      val blocksNum = openBlocks.blockIds.length
      val blocks = for (i <- (0 until blocksNum).view)
        yield blockManager.getBlockData(BlockId.apply(openBlocks.blockIds(i)))
      val streamId = streamManager.registerStream(appId, blocks.iterator.asJava,
        client.getChannel)
      logTrace(s"Registered streamId $streamId with $blocksNum buffers")
      responseContext.onSuccess(new StreamHandle(streamId, blocksNum).toByteBuffer)

    case uploadBlock: UploadBlock =>
      // StorageLevel and ClassTag are serialized as bytes using our JavaSerializer.
      val (level: StorageLevel, classTag: ClassTag[_]) = {
        serializer
          .newInstance()
          .deserialize(ByteBuffer.wrap(uploadBlock.metadata))
          .asInstanceOf[(StorageLevel, ClassTag[_])]
      }
      val data = new NioManagedBuffer(ByteBuffer.wrap(uploadBlock.blockData))
      val blockId = BlockId(uploadBlock.blockId)
      logDebug(s"Receiving replicated block $blockId with level ${level} " +
        s"from ${client.getSocketAddress}")
      blockManager.putBlockData(blockId, data, level, classTag)
      responseContext.onSuccess(ByteBuffer.allocate(0))
  }
}
For each request it registers a streamId with an associated StreamState and returns a StreamHandle to the fetching side, containing the streamId and the number of blocks. At registration time, the data of every requested block is read via getBlockData:
override def getBlockData(blockId: BlockId): ManagedBuffer = {
  if (blockId.isShuffle) {
    shuffleManager.shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId])
  } else {
    getLocalBytes(blockId) match {
      case Some(blockData) =>
        new BlockManagerManagedBuffer(blockInfoManager, blockId, blockData, true)
      case None =>
        reportBlockStatus(blockId, BlockStatus.empty)
        throw new BlockNotFoundException(blockId.toString)
    }
  }
}
Since this is a reduce-side shuffle read, blockId.isShuffle is true, so the block is resolved through the shuffle block resolver:
val shuffleBlockResolver = shuffleManager.shuffleBlockResolver
val buf = new ChunkedByteBuffer(
  shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId]).nioByteBuffer())
Some(new ByteBufferBlockData(buf, true))
The actual read is done by IndexShuffleBlockResolver. As described in the previous section, the map-side write also produces an index file, and that index file is used here to locate the data for the requested reduce partition:
override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = {
  // The block is a contiguous range of a single map output file, so look up the offsets
  // for this reduce partition in the index file written by the map side.
  val indexFile = getIndexFile(blockId.shuffleId, blockId.mapId)

  val channel = Files.newByteChannel(indexFile.toPath)
  channel.position(blockId.reduceId * 8L)
  val in = new DataInputStream(Channels.newInputStream(channel))
  try {
    val offset = in.readLong()
    val nextOffset = in.readLong()
    val actualPosition = channel.position()
    val expectedPosition = blockId.reduceId * 8L + 16
    if (actualPosition != expectedPosition) {
      throw new Exception(s"SPARK-22982: Incorrect channel position after index file reads: " +
        s"expected $expectedPosition but actual position was $actualPosition.")
    }
    new FileSegmentManagedBuffer(
      transportConf,
      getDataFile(blockId.shuffleId, blockId.mapId),
      offset,
      nextOffset - offset)
  } finally {
    in.close()
  }
}
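So the index file for a map output is simply a sequence of longs: the entry at byte position reduceId * 8 is the start offset of that reduce partition's segment in the data file, and the following long is where the next partition starts. A standalone sketch (my own illustration, not Spark code) of reading such an index file:

import java.io.{DataInputStream, File}
import java.nio.channels.{Channels, FileChannel}
import java.nio.file.StandardOpenOption

// Hypothetical helper: given a shuffle index file and a reduceId, return the
// (offset, length) of that reducer's segment inside the corresponding data file.
object ShuffleIndexReader {
  def segmentFor(indexFile: File, reduceId: Int): (Long, Long) = {
    val channel = FileChannel.open(indexFile.toPath, StandardOpenOption.READ)
    try {
      channel.position(reduceId * 8L)                 // each entry is one 8-byte long
      val in = new DataInputStream(Channels.newInputStream(channel))
      val offset = in.readLong()                      // where this reducer's data starts
      val nextOffset = in.readLong()                  // where the next reducer's data starts
      (offset, nextOffset - offset)                   // (start offset, segment length)
    } finally {
      channel.close()
    }
  }
}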