[SPARK][CORE] 面试问题之 Shuffle reader 的细枝末节 (下)

2022-06-08 18:14:55 浏览数 (1)

在Spark中shuffleWriter有三种实现,分别是bypassMergeSortShuffleWriter, UnsafeShuffleWriter和SortShuffleWriter。但是shuffleReader却只有一种实现BlockStoreShuffleReader

从上一讲中可以知道,这时Spark已经获取到了shuffle元数据包括每个mapId和其location信息,并将其传递给BlockStoreShuffleReader类。接下来我们来详细分析下BlockStoreShuffleReader的实现。

代码语言:javascript复制
// BlockStoreShuffleReader
override def read(): Iterator[Product2[K, C]] = {
  // [1] 初始化ShuffleBlockFetcherIterator,负责从executor中获取 shuffle 块
  val wrappedStreams = new ShuffleBlockFetcherIterator(
    context,
    blockManager.blockStoreClient,
    blockManager,
    mapOutputTracker,
    blocksByAddress,
    ...
    readMetrics,
    fetchContinuousBlocksInBatch).toCompletionIterator

  val serializerInstance = dep.serializer.newInstance()

  // [2] 将shuffle 块反序列化为record迭代器
  // Create a key/value iterator for each stream
  val recordIter = wrappedStreams.flatMap { case (blockId, wrappedStream) =>
    // Note: the asKeyValueIterator below wraps a key/value iterator inside of a
    // NextIterator. The NextIterator makes sure that close() is called on the
    // underlying InputStream when all records have been read.
    serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
  }

  // Update the context task metrics for each record read.
  val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
    recordIter.map { record =>
      readMetrics.incRecordsRead(1)
      record
    },
    context.taskMetrics().mergeShuffleReadMetrics())

  // An interruptible iterator must be used here in order to support task cancellation
  val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)
   // [3] reduce端聚合数据:如果map端已经聚合过了,则对读取到的聚合结果进行聚合。如果map端没有聚合,则针对未合并的<k,v>进行聚合。
  val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
    if (dep.mapSideCombine) {
      // We are reading values that are already combined
      val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
    } else {
      // We don't know the value type, but also don't care -- the dependency *should*
      // have made sure its compatible w/ this aggregator, which will convert the value
      // type to the combined type C
      val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
    }
  } else {
    interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
  }
  // [4] reduce端排序数据:如果需要对key排序,则进行排序。基于sort的shuffle实现过程中,默认只是按照partitionId排序。在每一个partition内部并没有排序,因此添加了keyOrdering变量,提供是否需要对分区内部的key排序
  // Sort the output if there is a sort ordering defined.
  val resultIter: Iterator[Product2[K, C]] =dep.keyOrdering match {
    caseSome(keyOrd: Ordering[K]) =>
      // Create an ExternalSorter to sort the data.
      val sorter =
        new ExternalSorter[K, C, C](context, ordering =Some(keyOrd), serializer =dep.serializer)
      sorter.insertAllAndUpdateMetrics(aggregatedIter)
    case None =>
      aggregatedIter
  }

  // [5] 返回结果集迭代器
  resultIter match {
    case _: InterruptibleIterator[Product2[K, C]] => resultIter
    case _ =>
      // Use another interruptible iterator here to support task cancellation as aggregator
      // or(and) sorter may have consumed previous interruptible iterator.
      new InterruptibleIterator[Product2[K, C]](context, resultIter)
  }
}

从上面可见,在BlockStoreShuffleReader.read()读取数据有五步:

  • [1] 初始化ShuffleBlockFetcherIterator,负责从executor中获取 shuffle 块
  • [2] 将shuffle 块反序列化为record迭代器
  • [3] reduce端聚合数据:如果map端已经聚合过了,则对读取到的聚合结果进行聚合。如果map端没有聚合,则针对未合并的<k,v>进行聚合。
  • [4] reduce端排序数据:如果需要对key排序,则进行排序。基于sort的shuffle实现过程中,默认只是按照partitionId排序。在每一个partition内部并没有排序,因此添加了keyOrdering变量,提供是否需要对分区内部的key排序
  • [5] 返回结果集迭代器

下面我们详细分析下ShuffleBlockFetcherIterator是如何进行fetch数据的

ShuffleBlockFetcherIterator是如何进行fetch数据的?

当shuffle reader创建 ShuffleBlockFetcherIterator 的实例时,迭代器调用在其initialize()方法。

代码语言:javascript复制
// ShuffleBlockFetcherIterator
private[this] def initialize(): Unit = {
  // Add a task completion callback (called in both success case and failure case) to cleanup.
  context.addTaskCompletionListener(onCompleteCallback)
  // Local blocks to fetch, excluding zero-sized blocks.
  val localBlocks = mutable.LinkedHashSet[(BlockId, Int)]()
  val hostLocalBlocksByExecutor =
    mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]]()
  val pushMergedLocalBlocks = mutable.LinkedHashSet[BlockId]()
  // [1] 划分数据源的请求:本地、主机本地和远程块
  // Partition blocks by the different fetch modes: local, host-local, push-merged-local and
  // remote blocks.
  val remoteRequests = partitionBlocksByFetchMode(
    blocksByAddress, localBlocks, hostLocalBlocksByExecutor, pushMergedLocalBlocks)
  // [2] 以随机顺序将远程请求添加到我们的队列中
  // Add the remote requests into our queue in a random order
  fetchRequests   = Utils.randomize(remoteRequests)
  assert((0 ==reqsInFlight) == (0 ==bytesInFlight),
    "expected reqsInFlight = 0 but found reqsInFlight = "  reqsInFlight 
    ", expected bytesInFlight = 0 but found bytesInFlight = "  bytesInFlight)

  // [3] 发送remote fetch请求
  // Send out initial requests for blocks, up to our maxBytesInFlight
  fetchUpToMaxBytes()

  val numDeferredRequest = deferredFetchRequests.values.map(_.size).sum
  val numFetches = remoteRequests.size -fetchRequests.size - numDeferredRequest
  logInfo(s"Started$numFetches remote fetches in${Utils.getUsedTimeNs(startTimeNs)}"  
    (if (numDeferredRequest > 0 ) s", deferred$numDeferredRequest requests" else ""))
  // [4] 支持executor获取local和remote的merge shuffle数据
  // Get Local Blocks
  fetchLocalBlocks(localBlocks)
  logDebug(s"Got local blocks in${Utils.getUsedTimeNs(startTimeNs)}")
  // Get host local blocks if any
  fetchAllHostLocalBlocks(hostLocalBlocksByExecutor)
pushBasedFetchHelper.fetchAllPushMergedLocalBlocks(pushMergedLocalBlocks)
}

在shuffle fetch的迭代器中,获取数据请求有下面四步:

  • [1] 通过不同的获取模式对块进行分区:本地、主机本地和远程块
  • [2] 以随机顺序将远程请求添加到我们的队列中
  • [3] 发送remote fetch请求
  • [4] 获取local blocks
  • [5] 获取host blocks
  • [6] 获取pushMerge的local blocks

划分数据源的请求

代码语言:javascript复制
private[this] def partitionBlocksByFetchMode(
    blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])],
    localBlocks: mutable.LinkedHashSet[(BlockId, Int)],
    hostLocalBlocksByExecutor: mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]],
    pushMergedLocalBlocks: mutable.LinkedHashSet[BlockId]): ArrayBuffer[FetchRequest] = {
  ...

val fallback = FallbackStorage.FALLBACK_BLOCK_MANAGER_ID.executorId
  val localExecIds =Set(blockManager.blockManagerId.executorId, fallback)
  for ((address, blockInfos) <- blocksByAddress) {
    checkBlockSizes(blockInfos)
    // [1] 如果是push-merged blocks, 判断其是否是主机的还是远程请求
    if (pushBasedFetchHelper.isPushMergedShuffleBlockAddress(address)) {
      // These are push-merged blocks or shuffle chunks of these blocks.
      if (address.host == blockManager.blockManagerId.host) {
numBlocksToFetch = blockInfos.size
        pushMergedLocalBlocks   = blockInfos.map(_._1)
        pushMergedLocalBlockBytes  = blockInfos.map(_._2).sum
      } else {
        collectFetchRequests(address, blockInfos, collectedRemoteRequests)
      }
     // [2] 如果是localexecIds, 放入localBlocks
    } else if (localExecIds.contains(address.executorId)) {
      val mergedBlockInfos =mergeContinuousShuffleBlockIdsIfNeeded(
        blockInfos.map(info =>FetchBlockInfo(info._1, info._2, info._3)), doBatchFetch)
numBlocksToFetch = mergedBlockInfos.size
      localBlocks   = mergedBlockInfos.map(info => (info.blockId, info.mapIndex))
      localBlockBytes  = mergedBlockInfos.map(_.size).sum
    // [3] 如果是host本地,并将其放入hostLocalBlocksByExecutor
    } else if (blockManager.hostLocalDirManager.isDefined &&
      address.host == blockManager.blockManagerId.host) {
      val mergedBlockInfos =mergeContinuousShuffleBlockIdsIfNeeded(
        blockInfos.map(info =>FetchBlockInfo(info._1, info._2, info._3)), doBatchFetch)
numBlocksToFetch = mergedBlockInfos.size
      val blocksForAddress =
        mergedBlockInfos.map(info => (info.blockId, info.size, info.mapIndex))
      hostLocalBlocksByExecutor  = address -> blocksForAddress
      numHostLocalBlocks  = blocksForAddress.size
      hostLocalBlockBytes  = mergedBlockInfos.map(_.size).sum
    // [4] 如果是remote请求,收集fetch请求, 每个请求的最大请求数据大小,是max(maxBytesInFlight / 5, 1L),这是为了提高请求的并发度,保证至少向5个不同的节点发送请求获取数据,最大限度地利用各节点的资源
    } else {
      val (_, timeCost) = Utils.timeTakenMs[Unit] {
        collectFetchRequests(address, blockInfos, collectedRemoteRequests)
      }
      logDebug(s"Collected remote fetch requests for$address in$timeCost ms")
    }
  }
  val (remoteBlockBytes, numRemoteBlocks) =
    collectedRemoteRequests.foldLeft((0L, 0))((x, y) => (x._1   y.size, x._2   y.blocks.size))
  val totalBytes = localBlockBytes   remoteBlockBytes   hostLocalBlockBytes  
    pushMergedLocalBlockBytes
  val blocksToFetchCurrentIteration =numBlocksToFetch- prevNumBlocksToFetch
  ...
  this.hostLocalBlocks  = hostLocalBlocksByExecutor.values
    .flatMap { infos => infos.map(info => (info._1, info._3)) }
  collectedRemoteRequests
}
  • [1] 如果是push-merged blocks, 判断其是否是主机的还是远程请求
  • [2] 如果是localexecIds, 放入localBlocks
  • [3] 如果是host本地,并将其放入hostLocalBlocksByExecutor
  • [4] 如果是remote请求,收集fetch请求, 每个请求的最大请求数据大小,是max(maxBytesInFlight / 5, 1L),这是为了提高请求的并发度,保证至少向5个不同的节点发送请求获取数据,最大限度地利用各节点的资源

在划分完数据的请求类别后,会依次的进行remote fetch请求,local blocks请求,host blocks请求和获取pushMerge的local blocks。

那么数据是如何被Fetch的呢?接下来我们看下fetchUpToMaxBytes()方法。

代码语言:javascript复制
private def fetchUpToMaxBytes(): Unit = {
  // [1] 如果是延迟请求,如果可以远程块Fetch同时是未达到处理请求的字节数,进行send请求
  if (deferredFetchRequests.nonEmpty) {
    for ((remoteAddress, defReqQueue) <-deferredFetchRequests) {
      while (isRemoteBlockFetchable(defReqQueue) &&
          !isRemoteAddressMaxedOut(remoteAddress, defReqQueue.front)) {
        val request = defReqQueue.dequeue()
        logDebug(s"Processing deferred fetch request for$remoteAddress with "
            s"${request.blocks.length} blocks")
        send(remoteAddress, request)
        if (defReqQueue.isEmpty) {
deferredFetchRequests-= remoteAddress
        }
      }
    }
  }

  // [2] 如果正常可以远程Fetch请求,直接send请求;如果达到处理请求的字节,则创建remoteAddress的延迟请求
  // Process any regular fetch requests if possible.
  while (isRemoteBlockFetchable(fetchRequests)) {
    val request = fetchRequests.dequeue()
    val remoteAddress = request.address
    if (isRemoteAddressMaxedOut(remoteAddress, request)) {
      logDebug(s"Deferring fetch request for$remoteAddress with${request.blocks.size} blocks")
      val defReqQueue = deferredFetchRequests.getOrElse(remoteAddress, new Queue[FetchRequest]())
      defReqQueue.enqueue(request)
deferredFetchRequests(remoteAddress) = defReqQueue
    } else {
      send(remoteAddress, request)
    }
  }
}

Fetch请求字节数据:

  • [1] 如果是延迟请求,如果可以远程块Fetch同时是未达到处理请求的字节数,进行send请求
  • [2] 如果正常可以远程Fetch请求,直接send请求;如果达到处理请求的字节,则创建remoteAddress的延迟请求

它会验证该请求是否应被视为延迟。如果是,则将其添加到deferredFetchRequests中。否则,它会继续并从BlockStoreClient实现发送请求(如果启用了 shuffle 服务,则为ExternalBlockStoreClient ,否则为NettyBlockTransferService)。

代码语言:javascript复制
// ShuffleBlockFetcherIterator
private[this] def sendRequest(req: FetchRequest): Unit = {
      // ...
      // [1] 创建了一个**BlockFetchingListener**,在完成请求后会被调用
      val blockFetchingListener = new BlockFetchingListener {
      override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = {
      // ...
      remainingBlocks -= blockId
      results.put(new SuccessFetchResult(BlockId(blockId), infoMap(blockId)._2,
      address, infoMap(blockId)._1, buf, remainingBlocks.isEmpty))
      // ...
      }
      override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {
        results.put(new FailureFetchResult(BlockId(blockId), infoMap(blockId)._2, address, e))
      }
    }

    // Fetch remote shuffle blocks to disk when the request is too large. Since the shuffle data is
    // already encrypted and compressed over the wire(w.r.t. the related configs), we can just fetch
    // the data and write it to file directly.
    // [2] 如果请求大小超过可以存储在内存中的请求的最大大小 ,则迭代器通过可选地定义DownloadFileManager来发送获取请求
    if (req.size > maxReqSizeShuffleToMem) {
      shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
        blockFetchingListener, this)
    } else {
      shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
        blockFetchingListener, null)
    }

在sendRequest中主要进行了以下两个步骤:

  • [1] 创建了一个BlockFetchingListener,在完成请求后会被调用
  • [2] 如果请求大小超过可以存储在内存中的请求的最大大小 ,则迭代器通过可选地定义DownloadFileManager来发送获取请求

Ued.png

首先,ShuffleBlockFetcherIterator迭代器创建了一个BlockFetchingListener,在其中定义成功执行和实现执行后的回调函数,如果成功执行,它会首先为迭代器加synchronized锁,然后将块数据添加到结果变量中。如果发生错误,同样会先加synchronized锁,然后它将添加一个标记类来指示获取失败。

其次,ShuffleBlockFetcherIterator会调用BlockStoreClient的fetchBlocks方法,在调用前会判断请求的内容的大小,如果超过门限,则传参定义DownloadFileManager,它会使得shuffleData将被下载到临时文件。

下面我们看下最终的fetchBlocks是如何实现的?

代码语言:javascript复制
@Override
public void fetchBlocks(
    String host,
    int port,
    String execId,
    String[] blockIds,
    BlockFetchingListener listener,
    DownloadFileManager downloadFileManager) {
  checkInit();
  logger.debug("External shuffle fetch from {}:{} (executor id {})", host, port, execId);
  try {
    // [1] 首先创建并初始化RetryingBlockFetcher类,用它加载shuffle files
    int maxRetries = transportConf.maxIORetries();
    RetryingBlockTransferor.BlockTransferStarter blockFetchStarter =
        (inputBlockId, inputListener) -> {
          // Unless this client is closed.
          if (clientFactory != null) {
            assert inputListener instanceof BlockFetchingListener :
              "Expecting a BlockFetchingListener, but got "   inputListener.getClass();
            TransportClient client = clientFactory.createClient(host, port, maxRetries > 0);
           // [2] 创建OneForOneBlockFetcher,用其进行下载shuffle Data
            new OneForOneBlockFetcher(client, appId, execId, inputBlockId,
              (BlockFetchingListener) inputListener, transportConf, downloadFileManager).start();
          } else {
            logger.info("This clientFactory was closed. Skipping further block fetch retries.");
          }
        };
      ...
      // [3] 调用OneForOneBlockFetcher的start方法
      blockFetchStarter.createAndStart(blockIds, listener);
    }
}
  • [1] 首先创建并初始化RetryingBlockFetcher类,用它加载shuffle files
  • [2] 创建OneForOneBlockFetcher,用其进行下载shuffle Data

OneForOneBlockFetcher进行Shuffle 数据的下载

OneForOneBlockFetcher是基于RPC通信,从各个Executor端获取shuffle数据,我们首先来简要概述下:

  • 首先,fetcher 会向持有 shuffle 文件的 executor发送FetchShuffleBlocks消息;
  • 其次,executor将register new Stream 同时返回StreamHandle消息到fetcher, 它带有streamId;
  • 在收到StreamHandle响应后,client将stream或load 数据块;
  • 如果downloadFileManager 不为空,则会将结果写入临时文件;对于内存的场景,shuffle bytes将加载到in-memory buffer中;
  • 最终,基于临时文件还是基于内存都会调用sendRequest中定义的BlockFetchingListener回调函数。

itled.png

获取到的shuffle data会被放入到new LinkedBlockingQueue[FetchResult],并调用next()方法。如果所有可用的块数据都已被消耗,迭代器将执行之前提供的 fetchUpToMaxBytes()。

ShuffleBlockFetcherIterator初始化完成后

在ShuffleBlockFetcherIterator初始化完成后,我们再来看看剩余的工作:

代码语言:javascript复制
private class ShuffleFetchCompletionListener(var data: ShuffleBlockFetcherIterator)
  extends TaskCompletionListener {
  override def onTaskCompletion(context: TaskContext): Unit = {
    if (data != null) {
      data.cleanup()locations(blocksByAddress)
      data = null
    }
  }
  def onComplete(context: TaskContext): Unit = this.onTaskCompletion(context)
}

在ShuffleBlockFetcherIterator初始化完成后,会将其转换为CompletionIterator,在其中主要是进行资源的释放。然后借助于反序列化器将其将shuffle block反序列化为record迭代器。在将其包装为metricIter 同于更新task的metric。之后再将其封装为InterruptibleIterator迭代器。可中断迭代器的作用是每次执行hasNext方法时,它都会分析任务状态并最终终止托管此迭代器的任务。主要用于启用了推测执行的情况。

代码语言:javascript复制
val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)

def hasNext: Boolean = {
    // TODO(aarondav/rxin): Check Thread.interrupted instead of context.interrupted if interrupt
    // is allowed. The assumption is that Thread.interrupted does not have a memory fence in read
    // (just a volatile field in C), while context.interrupted is a volatile in the JVM, which
    // introduces an expensive read fence.
    context.killTaskIfInterrupted()
    delegate.hasNext
 }

接下来就是reduce端的聚合排序的操作, 注意这里需要在ShuffleDependency中定义, aggregator和keyOrdering,这些操作需要在PairRDDFunctions 中进行定义。

但是在SparkSQL中,它采用的是ShuffleExchangeExec并不会定义 aggregator和keyOrdering,那么Spark SQL是如何实现聚合和排序的呢?

代码语言:javascript复制
val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
    ...
  } else {
    interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
  }

val resultIter: Iterator[Product2[K, C]] =dep.keyOrdering match {
    caseSome(keyOrd: Ordering[K]) =>
      val sorter =
        new ExternalSorter[K, C, C](context, ordering =Some(keyOrd), serializer =dep.serializer)
      sorter.insertAllAndUpdateMetrics(aggregatedIter)
    case None =>
      aggregatedIter
  }

其实通过其执行计划可以知道,其会在其中插入Sort算子来实现聚合排序。

到此为止,shuffle reader的大致过程已经走了一遍,但是还有很多的重要细节并没有展开探讨,那么这里就详细总结下整体的流程:

Fetch前的准备

  1. fetch reader 的调用主要是ShuffledRDD和ShuffledRowRDD中,通过传入 不同的partitionspecs给getReader传入不同的调用参数。
  2. 在getReader中会先通过mapOutputTracker获取mapid对应的shuffle文件的位置,然后在通过BlockStoreShuffleReader reader的唯一实现类进行shuffle fetch;
  3. 在Driver端mapOutputTracker记录mapId和对应的文件位置主要由MapOutputTrackerMaster进行维护,在创建mapShuffleStage时会向master tracker中注册shuffleid, 在完成mapStage时会更新对应shuffleId中维护的mapid对应的位置信息。在Executor端从MapOutputTrackerWorker中获取位置信息,如果获取不到会向master tracker发送信息,同步信息过来;

处理Fetch请求

  1. 在BlockStoreShuffleReader中进行fetch时,会先创建ShuffleBlockFetcherIterator, 并将Fetch分为local, host local, remote不同方式;同时在Fetch时也会有些限制,包括每个Excutor阻塞的fetch request数和fetch shuffle数据是否大于分配的内存;如果请求的数据量过多,超过了内存限制,将通过写入临时文件实现;如果网络通信开销太大,fetcher 将停止读取,并在需要下一个 shuffle 块文件时恢复读取。
  2. 最终的Fetch是通过OneForOneBlockFetcher实现的,fetcher 会向持有 shuffle 文件的 executor发送FetchShuffleBlocks消息,executor将register new Stream 同时将数据封装为StreamHandle消息返回到fetcher,client最后再将加载数据块;最终调用BlockFetchingListener回调函数。

Fetch后的处理

  1. reduce端聚合数据:如果map端已经聚合过了,则对读取到的聚合结果进行聚合。如果map端没有聚合,则针对未合并的<k,v>进行聚合。
  2. reduce端排序数据:如果需要对key排序,则进行排序。基于sort的shuffle实现过程中,默认只是按照partitionId排序。在每一个partition内部并没有排序,因此添加了keyOrdering变量,提供是否需要对分区内部的key排序
  3. 另外需要注意的是SparkSQL中并不会设置ShuffleDependency的排序和聚合,而是通过规则在逻辑树中插入Sort算子实现的。

学完Shuffle Reader下面是一些思考题:

  1. 为什么在调用getReader时要根据partitionspecs的不同传递不同的参数?主要的作用是什么?
  2. 远程Fetch和本地Fetch最大的区别是什么?
  3. InterruptibleIterator 和 CompletionIterator 迭代器的作用是什么?

0 人点赞