Spark Shuffle Module: Shuffle Read Process Analysis


Before reading this article, please read the Spark Sort Based Shuffle memory analysis first.

The Spark Shuffle Read call stack is as follows:
1. org.apache.spark.rdd.ShuffledRDD#compute()
2. org.apache.spark.shuffle.ShuffleManager#getReader()
3. org.apache.spark.shuffle.hash.HashShuffleReader#read()
4. org.apache.spark.storage.ShuffleBlockFetcherIterator#initialize()
5. org.apache.spark.storage.ShuffleBlockFetcherIterator#splitLocalRemoteBlocks()
org.apache.spark.storage.ShuffleBlockFetcherIterator#sendRequest()
org.apache.spark.storage.ShuffleBlockFetcherIterator#fetchLocalBlocks()

Below are the classes and corresponding methods involved when fetchLocalBlocks() executes:
6. org.apache.spark.storage.BlockManager#getBlockData()
org.apache.spark.shuffle.ShuffleManager#shuffleBlockResolver()
ShuffleManager has two subclasses. For Hash Shuffle the method is org.apache.spark.shuffle.hash.HashShuffleManager#shuffleBlockResolver(), which returns an org.apache.spark.shuffle.FileShuffleBlockResolver; FileShuffleBlockResolver#getBlockData() is then called to return the block data. For Sort Shuffle the method is org.apache.spark.shuffle.sort.SortShuffleManager#shuffleBlockResolver(), which returns an org.apache.spark.shuffle.IndexShuffleBlockResolver; IndexShuffleBlockResolver#getBlockData() is then called to return the block data.

Below are the classes and corresponding methods involved when org.apache.spark.storage.ShuffleBlockFetcherIterator#sendRequest() executes:
7. org.apache.spark.network.shuffle.ShuffleClient#fetchBlocks()
org.apache.spark.network.shuffle.ShuffleClient has two subclasses, ExternalShuffleClient and BlockTransferService. org.apache.spark.network.BlockTransferService in turn has two subclasses, NettyBlockTransferService and NioBlockTransferService, which correspond to two different ways of fetching remote block data. In Spark 1.5.2 the NioBlockTransferService path has been marked deprecated and will be removed in a later release.

The methods are explained below following the call stack above. Only the main flow is covered here; details are discussed later.

ShuffledRDD#compute()

When a task runs, it calls ShuffledRDD's compute method, whose code is as follows:

//org.apache.spark.rdd.ShuffledRDD#compute()
override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = {
    val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
    //Get the reader via org.apache.spark.shuffle.ShuffleManager#getReader()
    //Whether Sort Shuffle or Hash Shuffle is in use, the reader is always
    //org.apache.spark.shuffle.hash.HashShuffleReader
    SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
      .read()
      .asInstanceOf[Iterator[(K, C)]]
  }

As can be seen, the core logic is to call ShuffleManager#getReader() to obtain a HashShuffleReader and then call HashShuffleReader#read() to read the shuffle data produced by the ShuffleMapTasks of the previous stage. Note that HashShuffleReader is used regardless of whether Hash Shuffle or Sort Shuffle is in effect.
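
For a concrete sense of when this path runs, here is a minimal sketch (the application name and data are made up for illustration): a reduceByKey job creates a ShuffledRDD, and each reduce task's call to compute() goes through exactly the read path traced in this article.

// A minimal, self-contained example (hypothetical app) whose reduce stage
// goes through ShuffledRDD#compute() -> HashShuffleReader#read()
import org.apache.spark.{SparkConf, SparkContext}

object ShuffleReadDemo {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("ShuffleReadDemo").setMaster("local[2]")
    val sc = new SparkContext(conf)
    // reduceByKey introduces a ShuffleDependency; the resulting RDD is a ShuffledRDD
    val counts = sc.parallelize(Seq("a", "b", "a", "c", "b", "a"))
      .map(word => (word, 1))
      .reduceByKey(_ + _)   // map-side combine is enabled by default
    // Each reduce task calls ShuffledRDD#compute(), which reads the shuffle
    // output of the previous stage via HashShuffleReader#read()
    counts.collect().foreach(println)
    sc.stop()
  }
}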

HashShuffleReader#read()

Jumping into HashShuffleReader#read(), its source is as follows:

/** Read the combined key-values for this reduce task */
  override def read(): Iterator[Product2[K, C]] = {
    //Create a ShuffleBlockFetcherIterator; its constructor calls initialize()
    //initialize() runs splitLocalRemoteBlocks() to decide how the data is read:
    //remote data is fetched with sendRequest(),
    //local data is read with fetchLocalBlocks()
    val blockFetcherItr = new ShuffleBlockFetcherIterator(
      context,
      blockManager.shuffleClient,
      blockManager,
      mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition),
      // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
      SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024)

    // Wrap the streams for compression based on configuration
    val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) =>
      blockManager.wrapForCompression(blockId, inputStream)
    }

    val ser = Serializer.getSerializer(dep.serializer)
    val serializerInstance = ser.newInstance()

    // Create a key/value iterator for each stream
    val recordIter = wrappedStreams.flatMap { wrappedStream =>
      // Note: the asKeyValueIterator below wraps a key/value iterator inside of a
      // NextIterator. The NextIterator makes sure that close() is called on the
      // underlying InputStream when all records have been read.
      serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
    }

    // Update the context task metrics for each record read.
    val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
    val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
      recordIter.map(record => {
        readMetrics.incRecordsRead(1)
        record
      }),
      context.taskMetrics().updateShuffleReadMetrics())

    // An interruptible iterator must be used here in order to support task cancellation
    val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)

    val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
      if (dep.mapSideCombine) { 
        // Read data already combined on the map side and merge the combiners
        val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
        dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
      } else {
        //Combine the raw values on the reduce side
        val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
        dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
      }
    } else {
      require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
      interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
    }

    // Sort the output if a key ordering is defined
    dep.keyOrdering match {
      case Some(keyOrd: Ordering[K]) =>
        // Create an ExternalSorter to sort the data. Note that if spark.shuffle.spill is disabled,
        // the ExternalSorter won't spill to disk.
        val sorter = new ExternalSorter[K, C, C](ordering = Some(keyOrd), serializer = Some(ser))
        sorter.insertAll(aggregatedIter)
        context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
        context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
        context.internalMetricsToAccumulators(
          InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.peakMemoryUsedBytes)
        sorter.iterator
      case None =>
        aggregatedIter
    }
  }
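
To make the aggregation branch concrete, below is a small standalone sketch (plain Scala collections, not the Spark API) of what the Aggregator does in the two branches: when mapSideCombine is true, the map side has already produced partial combiners, so the reduce side only merges combiners; otherwise the reduce side builds combiners from raw values. The word-count style aggregator is an illustrative assumption.

// Standalone sketch of the two aggregation branches, assuming a word-count
// aggregator: createCombiner wraps a value, mergeValue/mergeCombiners sum
object AggregationBranches {
  def main(args: Array[String]): Unit = {
    val createCombiner: Int => Int = v => v
    val mergeValue: (Int, Int) => Int = _ + _
    val mergeCombiners: (Int, Int) => Int = _ + _

    // Branch 1: mapSideCombine == true, the shuffle data already holds (K, C) pairs
    val combined = Iterator(("a", 2), ("b", 1), ("a", 1))
    val fromCombiners = combined.foldLeft(Map.empty[String, Int]) {
      case (acc, (k, c)) => acc.updated(k, acc.get(k).map(mergeCombiners(_, c)).getOrElse(c))
    }

    // Branch 2: mapSideCombine == false, the shuffle data holds raw (K, V) pairs
    val raw = Iterator(("a", 1), ("a", 1), ("b", 1), ("a", 1))
    val fromValues = raw.foldLeft(Map.empty[String, Int]) {
      case (acc, (k, v)) =>
        acc.updated(k, acc.get(k).map(mergeValue(_, v)).getOrElse(createCombiner(v)))
    }

    println(fromCombiners) // Map(a -> 3, b -> 1)
    println(fromValues)    // Map(a -> 3, b -> 1)
  }
}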

ShuffleBlockFetcherIterator#splitLocalRemoteBlocks()

splitLocalRemoteBlocks() decides how the data will be read: the localBlocks variable records the BlockIds that live on the local machine, while remoteBlocks records all the BlockIds held on remote machines. Remote blocks are batched into FetchRequests (collected in remoteRequests, an ArrayBuffer[FetchRequest]); each request is capped at maxBytesInFlight / 5 so that the data in flight stays within maxBytesInFlight while still fetching from several nodes in parallel.

The source of splitLocalRemoteBlocks() is as follows:

private[this] def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
    // Make remote requests at most maxBytesInFlight / 5 in length; the reason to keep them
    // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
    // nodes, rather than blocking on reading output from one node.
    //maxBytesInFlight is the maximum amount of fetched data allowed in flight at once; the default is 48 MB,
    //set via SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024
    val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)
    logDebug("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize)

    // Split local and remote blocks. Remote blocks are further split into FetchRequests of size
    // at most maxBytesInFlight in order to limit the amount of data in flight.
    val remoteRequests = new ArrayBuffer[FetchRequest]

    // Tracks total number of blocks (including zero sized blocks)
    var totalBlocks = 0
    for ((address, blockInfos) <- blocksByAddress) {
      totalBlocks += blockInfos.size
      //The data to fetch is on this machine
      if (address.executorId == blockManager.blockManagerId.executorId) {
        // Filter out zero-sized blocks
        //Record the BlockIds of the blocks that are local
        localBlocks ++= blockInfos.filter(_._2 != 0).map(_._1)
        numBlocksToFetch += localBlocks.size
      } else {
        //The data is on a remote machine
        val iterator = blockInfos.iterator
        var curRequestSize = 0L
        var curBlocks = new ArrayBuffer[(BlockId, Long)]
        while (iterator.hasNext) {
          val (blockId, size) = iterator.next()
          // Skip empty blocks
          if (size > 0) {
            curBlocks += ((blockId, size))
            //Record the BlockIds held on remote machines
            remoteBlocks += blockId
            numBlocksToFetch += 1
            curRequestSize += size
          } else if (size < 0) {
            throw new BlockException(blockId, "Negative block size " + size)
          }
          if (curRequestSize >= targetRequestSize) {
            // Add this FetchRequest
            remoteRequests += new FetchRequest(address, curBlocks)
            curBlocks = new ArrayBuffer[(BlockId, Long)]
            logDebug(s"Creating fetch request of $curRequestSize at $address")
            curRequestSize = 0
          }
        }
        // Add in the final request
        if (curBlocks.nonEmpty) {
          remoteRequests += new FetchRequest(address, curBlocks)
        }
      }
    }
    logInfo(s"Getting $numBlocksToFetch non-empty blocks out of $totalBlocks blocks")
    remoteRequests
  }
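
As a worked example of the batching above: with the default spark.reducer.maxSizeInFlight of 48 MB, targetRequestSize is about 9.6 MB, so blocks from one remote executor are grouped greedily until a group reaches roughly that size. The standalone sketch below (not Spark code; block ids and sizes are made up) reproduces that grouping rule.

// Standalone sketch of the greedy grouping used by splitLocalRemoteBlocks():
// each request accumulates blocks until the running total reaches
// targetRequestSize (maxBytesInFlight / 5)
object FetchRequestBatching {
  def main(args: Array[String]): Unit = {
    val maxBytesInFlight = 48L * 1024 * 1024
    val targetRequestSize = math.max(maxBytesInFlight / 5, 1L) // ~9.6 MB

    val blocks = Seq(("shuffle_0_0_3", 4L << 20), ("shuffle_0_1_3", 7L << 20),
                     ("shuffle_0_2_3", 2L << 20), ("shuffle_0_3_3", 9L << 20))

    val requests = scala.collection.mutable.ArrayBuffer[Seq[(String, Long)]]()
    var cur = Vector.empty[(String, Long)]
    var curSize = 0L
    for ((id, size) <- blocks if size > 0) {
      cur :+= ((id, size))
      curSize += size
      if (curSize >= targetRequestSize) { requests += cur; cur = Vector.empty; curSize = 0L }
    }
    if (cur.nonEmpty) requests += cur

    requests.zipWithIndex.foreach { case (r, i) =>
      println(s"request $i: ${r.map(_._1).mkString(", ")} (${r.map(_._2).sum / 1024 / 1024} MB)")
    }
  }
}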

ShuffleBlockFetcherIterator#fetchLocalBlocks()

fetchLocalBlocks() reads the local blocks by calling BlockManager's getBlockData method. Its source is as follows:

private[this] def fetchLocalBlocks() {
    val iter = localBlocks.iterator
    while (iter.hasNext) {
      val blockId = iter.next()
      try {
        //Call BlockManager's getBlockData method
        val buf = blockManager.getBlockData(blockId)
        shuffleMetrics.incLocalBlocksFetched(1)
        shuffleMetrics.incLocalBytesRead(buf.size)
        buf.retain()
        results.put(new SuccessFetchResult(blockId, blockManager.blockManagerId, 0, buf))
      } catch {
        case e: Exception =>
          // If we see an exception, stop immediately.
          logError(s"Error occurred while fetching local blocks", e)
          results.put(new FailureFetchResult(blockId, blockManager.blockManagerId, e))
          return
      }
    }
  }

Jumping to BlockManager's getBlockData method, its source is as follows:

override def getBlockData(blockId: BlockId): ManagedBuffer = {
    if (blockId.isShuffle) {
      //First call ShuffleManager's shuffleBlockResolver method to get the ShuffleBlockResolver,
      //then call its getBlockData method
      shuffleManager.shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId])
    } else {
      val blockBytesOpt = doGetLocal(blockId, asBlockResult = false)
        .asInstanceOf[Option[ByteBuffer]]
      if (blockBytesOpt.isDefined) {
        val buffer = blockBytesOpt.get
        new NioManagedBuffer(buffer)
      } else {
        throw new BlockNotFoundException(blockId.toString)
      }
    }
  }

org.apache.spark.shuffle.ShuffleManager#shuffleBlockResolver() returns the appropriate ShuffleBlockResolver: org.apache.spark.shuffle.FileShuffleBlockResolver for Hash Shuffle, and org.apache.spark.shuffle.IndexShuffleBlockResolver for Sort Shuffle. The corresponding resolver's getBlockData method is then called to return the matching FileSegment.

The source of FileShuffleBlockResolver#getBlockData is as follows:

override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = {
    //Files produced by the shuffle consolidate files mechanism of Hash Shuffle
    if (consolidateShuffleFiles) { 
      // Search all file groups associated with this shuffle.
      val shuffleState = shuffleStates(blockId.shuffleId)
      val iter = shuffleState.allFileGroups.iterator
      while (iter.hasNext) {
        val segmentOpt = iter.next.getFileSegmentFor(blockId.mapId, blockId.reduceId)
        if (segmentOpt.isDefined) {
          val segment = segmentOpt.get
          return new FileSegmentManagedBuffer(
            transportConf, segment.file, segment.offset, segment.length)
        }
      }
      throw new IllegalStateException("Failed to find shuffle block: " + blockId)
    } else {
      //Files produced by the plain Hash Shuffle mechanism
      val file = blockManager.diskBlockManager.getFile(blockId)
      new FileSegmentManagedBuffer(transportConf, file, 0, file.length)
    }
  }
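
For reference, in the non-consolidated branch the file looked up by diskBlockManager.getFile(blockId) is simply named after the block id, and a ShuffleBlockId renders as shuffle_<shuffleId>_<mapId>_<reduceId>. A minimal sketch of that naming scheme (mirroring, not quoting, org.apache.spark.storage.ShuffleBlockId):

// Sketch of how a shuffle block id maps to a file name (mirrors ShuffleBlockId#name)
case class SimpleShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) {
  def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId
}

object BlockNameDemo {
  def main(args: Array[String]): Unit = {
    // The reduce task asking for map output 5 of shuffle 0, partition 3,
    // looks up a file named "shuffle_0_5_3" under the executor's local dirs
    println(SimpleShuffleBlockId(shuffleId = 0, mapId = 5, reduceId = 3).name)
  }
}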

The source of IndexShuffleBlockResolver#getBlockData is as follows:

override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = {
    // The block is actually going to be a range of a single map output file for this map, so
    // find out the consolidated file, then the offset within that from our index
    //Use the shuffleId and mapId to find the corresponding index file
    val indexFile = getIndexFile(blockId.shuffleId, blockId.mapId)

    val in = new DataInputStream(new FileInputStream(indexFile))
    try {
      //Skip to this block's entry in the index file
      ByteStreams.skipFully(in, blockId.reduceId * 8)
      //Start offset of the data
      val offset = in.readLong()
      //End offset of the data
      val nextOffset = in.readLong()
      //Return the FileSegment
      new FileSegmentManagedBuffer(
        transportConf,
        getDataFile(blockId.shuffleId, blockId.mapId),
        offset,
        nextOffset - offset)
    } finally {
      in.close()
    }
  }
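
The arithmetic blockId.reduceId * 8 works because the index file is just a sequence of longs: the cumulative byte offsets of each reduce partition within the single data file, starting at 0, so partition r's segment is [offset(r), offset(r+1)). The standalone sketch below writes and reads such an index (in memory, with made-up partition lengths) to illustrate the layout assumed by the code above.

// Standalone sketch of the index-file layout read by getBlockData:
// (numPartitions + 1) longs holding cumulative offsets into the data file
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

object IndexFileLayoutDemo {
  def main(args: Array[String]): Unit = {
    val partitionLengths = Array(100L, 0L, 250L, 40L) // made-up segment sizes

    // Write the cumulative offsets: 0, 100, 100, 350, 390
    val bytes = new ByteArrayOutputStream()
    val out = new DataOutputStream(bytes)
    var offset = 0L
    out.writeLong(offset)
    partitionLengths.foreach { len => offset += len; out.writeLong(offset) }
    out.close()

    // Read the segment of reduce partition 2, just as getBlockData does
    val reduceId = 2
    val in = new DataInputStream(new ByteArrayInputStream(bytes.toByteArray))
    in.skipBytes(reduceId * 8)
    val start = in.readLong()
    val next = in.readLong()
    in.close()
    println(s"partition $reduceId: offset=$start length=${next - start}") // offset=100 length=250
  }
}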

ShuffleBlockFetcherIterator#sendRequest()

sendRequest() fetches block data from remote machines:

 private[this] def sendRequest(req: FetchRequest) {
    logDebug("Sending request for %d blocks (%s) from %s".format(
      req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort))
    bytesInFlight += req.size

    // so we can look up the size of each blockID
    val sizeMap = req.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap
    val blockIds = req.blocks.map(_._1.toString)

    val address = req.address
    //Fetch the data with ShuffleClient's fetchBlocks method
    //There are two kinds of ShuffleClient: ExternalShuffleClient and BlockTransferService
    //The default is BlockTransferService
    shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
      new BlockFetchingListener {
        override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = {
          // Only add the buffer to results queue if the iterator is not zombie,
          // i.e. cleanup() has not been called yet.
          if (!isZombie) {
            // Increment the ref count because we need to pass this to a different thread.
            // This needs to be released after use.
            buf.retain()
            results.put(new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf))
            shuffleMetrics.incRemoteBytesRead(buf.size)
            shuffleMetrics.incRemoteBlocksFetched(1)
          }
          logTrace("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
        }

        override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {
          logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e)
          results.put(new FailureFetchResult(BlockId(blockId), address, e))
        }
      }
    )
  }

As the code shows, shuffleClient.fetchBlocks is used to fetch remote block data. org.apache.spark.network.shuffle.ShuffleClient has two subclasses, ExternalShuffleClient and BlockTransferService, and org.apache.spark.network.BlockTransferService in turn has two subclasses, NettyBlockTransferService and NioBlockTransferService. The shuffleClient object is defined in org.apache.spark.storage.BlockManager, whose source is as follows:

// shuffleClient as defined in org.apache.spark.storage.BlockManager
 private[spark] val shuffleClient = if (externalShuffleServiceEnabled) {
    //Fetch remote block data with ExternalShuffleClient
    val transConf = SparkTransportConf.fromSparkConf(conf, numUsableCores)
    new ExternalShuffleClient(transConf, securityManager, securityManager.isAuthenticationEnabled(),
      securityManager.isSaslEncryptionEnabled())
  } else {
    //Fetch remote block data with NettyBlockTransferService or NioBlockTransferService
    blockTransferService
  }
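
Which branch is taken depends on externalShuffleServiceEnabled, driven by the spark.shuffle.service.enabled configuration (false by default). A minimal sketch of opting into the external shuffle client follows; note it only builds the SparkConf and assumes the external shuffle service is already running on each worker or NodeManager.

// Hedged sketch: enabling the external shuffle service so that BlockManager
// builds an ExternalShuffleClient instead of using blockTransferService
import org.apache.spark.SparkConf

object ExternalShuffleConfDemo {
  def main(args: Array[String]): Unit = {
    // The external shuffle service itself must already be running on each node;
    // this only tells executors to fetch shuffle data through it.
    val conf = new SparkConf()
      .setAppName("ExternalShuffleConfDemo")
      .set("spark.shuffle.service.enabled", "true")    // BlockManager builds an ExternalShuffleClient
      .set("spark.dynamicAllocation.enabled", "true")  // commonly paired with the external service
    println(conf.get("spark.shuffle.service.enabled"))
  }
}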

The blockTransferService in the code above is initialized in SparkEnv, as follows:

 //blockTransferService is initialized in org.apache.spark.SparkEnv
 val blockTransferService =
      conf.get("spark.shuffle.blockTransferService", "netty").toLowerCase match {
        case "netty" =>
          new NettyBlockTransferService(conf, securityManager, numUsableCores)
        case "nio" =>
          logWarning("NIO-based block transfer service is deprecated, " +
            "and will be removed in Spark 1.6.0.")
          new NioBlockTransferService(conf, securityManager)
      }
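
The transport implementation can therefore be chosen with the spark.shuffle.blockTransferService setting shown above: "netty" is the default, and "nio" selects the deprecated NioBlockTransferService. An illustrative way to pin it explicitly:

// Hedged sketch: explicitly selecting the (default) Netty-based transfer service
import org.apache.spark.SparkConf

object TransferServiceConfDemo {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setAppName("TransferServiceConfDemo")
      .set("spark.shuffle.blockTransferService", "netty") // "nio" would pick NioBlockTransferService
    println(conf.get("spark.shuffle.blockTransferService"))
  }
}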