SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起(4)

简介: SPARK中的wholeStageCodegen全代码生成--以aggregate代码生成为例说起(4)

背景

本文基于 SPARK 3.3.0

从一个unit test来探究SPARK Codegen的逻辑,


  test("SortAggregate should be included in WholeStageCodegen") {
    val df = spark.range(10).agg(max(col("id")), avg(col("id")))
    withSQLConf("spark.sql.test.forceApplySortAggregate" -> "true") {
      val plan = df.queryExecution.executedPlan
      assert(plan.exists(p =>
        p.isInstanceOf[WholeStageCodegenExec] &&
          p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SortAggregateExec]))
      assert(df.collect() === Array(Row(9, 4.5)))
    }
  }

该sql形成的执行计划第一部分的全代码生成部分如下:

   WholeStageCodegen     
   +- *(1) SortAggregate(key=[], functions=[partial_max(id#0L), partial_avg(id#0L)], output=[max#12L, sum#13, count#14L])
      +- *(1) Range (0, 10, step=1, splits=2)    


分析


第一阶段wholeStageCodegen

第一阶段的代码生成涉及到SortAggregateExec和RangeExec的produce和consume方法,这里一一来分析:

第一阶段wholeStageCodegen数据流如下:

WholeStageCodegenExec      SortAggregateExec(partial)     RangeExec        
  =========================================================================
  -> execute()
      |
   doExecute() --------->   inputRDDs() -----------------> inputRDDs() 
      |
   doCodeGen()
      |
      +----------------->   produce()
                              |
                           doProduce() 
                              |
                           doProduceWithoutKeys() -------> produce()
                                                              |
                                                          doProduce()
                                                              |
                           doConsume()<------------------- consume()
                              |
                           doConsumeWithoutKeys()
                              |并不是doConsumeWithoutKeys调用consume,而是由doProduceWithoutKeys调用
   doConsume()  <--------  consume()

RangeExec的consume方法

final def consume(ctx: CodegenContext, outputVars: Seq[ExprCode], row: String = null): String = {
    val inputVarsCandidate =
      if (outputVars != null) {
        assert(outputVars.length == output.length)
        // outputVars will be used to generate the code for UnsafeRow, so we should copy them
        outputVars.map(_.copy())
      } else {
        assert(row != null, "outputVars and row cannot both be null.")
        ctx.currentVars = null
        ctx.INPUT_ROW = row
        output.zipWithIndex.map { case (attr, i) =>
          BoundReference(i, attr.dataType, attr.nullable).genCode(ctx)
        }
      }
    val inputVars = inputVarsCandidate match {
      case stream: Stream[ExprCode] => stream.force
      case other => other
    }
    val rowVar = prepareRowVar(ctx, row, outputVars)
    // Set up the `currentVars` in the codegen context, as we generate the code of `inputVars`
    // before calling `parent.doConsume`. We can't set up `INPUT_ROW`, because parent needs to
    // generate code of `rowVar` manually.
    ctx.currentVars = inputVars
    ctx.INPUT_ROW = null
    ctx.freshNamePrefix = parent.variablePrefix
    val evaluated = evaluateRequiredVariables(output, inputVars, parent.usedInputs)
    // Under certain conditions, we can put the logic to consume the rows of this operator into
    // another function. So we can prevent a generated function too long to be optimized by JIT.
    // The conditions:
    // 1. The config "spark.sql.codegen.splitConsumeFuncByOperator" is enabled.
    // 2. `inputVars` are all materialized. That is guaranteed to be true if the parent plan uses
    //    all variables in output (see `requireAllOutput`).
    // 3. The number of output variables must less than maximum number of parameters in Java method
    //    declaration.
    val confEnabled = conf.wholeStageSplitConsumeFuncByOperator
    val requireAllOutput = output.forall(parent.usedInputs.contains(_))
    val paramLength = CodeGenerator.calculateParamLength(output) + (if (row != null) 1 else 0)
    val consumeFunc = if (confEnabled && requireAllOutput
        && CodeGenerator.isValidParamLength(paramLength)) {
      constructDoConsumeFunction(ctx, inputVars, row)
    } else {
      parent.doConsume(ctx, inputVars, rowVar)
    }
    s"""
       |${ctx.registerComment(s"CONSUME: ${parent.simpleString(conf.maxToStringFields)}")}
       |$evaluated
       |$consumeFunc
     """.stripMargin
  }

其中参数outputVars为传入的rangeExc产生的value


val inputVarsCandidate =和val inputVars =

对于outputVars 不为空的情况下,直接copy复制一份outputVars值作为输入的变量

如果outputVars为空,而row不为空的情况下,则说明传入的是InteralRow类型的变量,需要调用InteralRow对应的方法获取对应的值


val rowVar = prepareRowVar(ctx, row, outputVars)

这部分在RangeExec中不会用到,这里不讲解(因为rangExec这里数据流会走向constructDoConsumeFunction这里)


ctx.currentVars = inputVars ctx.INPUT_ROW = null ctx.freshNamePrefix = parent.variablePrefix

这里是为了对evaluateRequiredVariables方法做铺垫,因为


val evaluated = evaluateRequiredVariables(output, inputVars, parent.usedInputs)

其中这里的output 为 Range.getOutputAttrs,即StructType(StructField("id", LongType, nullable = false) :: Nil).toAttributes

inputVars为range_value_0

parent.usedInputs为AttributeSet(children.flatMap(_.asInstanceOf[QueryPlan[PlanType]].output)),和output一样,也就是Range.getOutputAttrs,即StructType(StructField("id", LongType, nullable = false) :: Nil).toAttributes

因为inputVars的code为空,所以 evaluated对于该inputVars计算也为空


val confEnabled val requireAllOutput

这里的两个条件都是 TRUE


val paramLength = CodeGenerator.calculateParamLength(output) + (if (row != null) 1 else 0)

计算表达式的长度,对于LONG和DOUBLE类型长度为2,其他的为1,因为range_value_0是LONG类型,所以总的长度为3


val consumeFunc =confEnabled && requireAllOutput&& CodeGenerator.isValidParamLength(paramLength)

这里的三个条件都满足,所以数据流向constructDoConsumeFunction方法,如下:

private def constructDoConsumeFunction(
    ctx: CodegenContext,
    inputVars: Seq[ExprCode],
    row: String): String = {
  val (args, params, inputVarsInFunc) = constructConsumeParameters(ctx, output, inputVars, row)
  val rowVar = prepareRowVar(ctx, row, inputVarsInFunc)
  val doConsume = ctx.freshName("doConsume")
  ctx.currentVars = inputVarsInFunc
  ctx.INPUT_ROW = null
  val doConsumeFuncName = ctx.addNewFunction(doConsume,
    s"""
       | private void $doConsume(${params.mkString(", ")}) throws java.io.IOException {
       |   ${parent.doConsume(ctx, inputVarsInFunc, rowVar)}
       | }
     """.stripMargin)
  s"""
     | $doConsumeFuncName(${args.mkString(", ")});
   """.stripMargin
}

其中inputVars为range_value_0

row为NULL


val (args, params, inputVarsInFunc) = constructConsumeParameters(ctx, output, inputVars, row)

构造 函数实参,形参,以及形参ExprCode变量,分别为range_value_0,long sortAgg_expr_0_0,sortAgg_expr_0_0

val rowVar = prepareRowVar(ctx, row, inputVarsInFunc)

这里是构造UnsafeRow类型的变量便于传给parent进行消费 ,其中 row为NULL,inputVarsInFunc为sortAgg_expr_0_0

private def prepareRowVar(ctx: CodegenContext, row: String, colVars: Seq[ExprCode]): ExprCode = {
  if (row != null) {
    ExprCode.forNonNullValue(JavaCode.variable(row, classOf[UnsafeRow]))
  } else {
    if (colVars.nonEmpty) {
      val colExprs = output.zipWithIndex.map { case (attr, i) =>
        BoundReference(i, attr.dataType, attr.nullable)
      }
      val evaluateInputs = evaluateVariables(colVars)
      // generate the code to create a UnsafeRow
      ctx.INPUT_ROW = row
      ctx.currentVars = colVars
      val ev = GenerateUnsafeProjection.createCode(ctx, colExprs, false)
      val code = code"""
        |$evaluateInputs
        |${ev.code}
       """.stripMargin
      ExprCode(code, FalseLiteral, ev.value)
    } else {
      // There are no columns
      ExprCode.forNonNullValue(JavaCode.variable("unsafeRow", classOf[UnsafeRow]))
    }
  }
}

对于val colExprs =

这块是针对当前物理计划的输出(output)与变量值进行绑定,对于RangeExec来说output的值为Range.getOutputAttrs,即StructType(StructField(“id”, LongType, nullable = false) :: Nil).toAttributes ,而当前rangexec的对应的变量为range_value_0


val evaluateInputs = evaluateVariables(colVars)

对于不是直接赋值的变量,而是通过计算得到的变量,则需要进行提前计算,在这里不需要计算。


val ev = GenerateUnsafeProjection.createCode(ctx, colExprs, false)

这部分是产生UnsafeRow类型的变量,这个UnsafeRow类型的变量里包含了rangExec的产生的变量rang_value_0

里面具体的细节,这里先忽略,以后会有具体的文章分析。


ExprCode(code, FalseLiteral, ev.value)

这里就返回ExprCode类型的数据结构,

其中code如下:range_mutableStateArray_0[0].reset();range_mutableStateArray_0[0].write(0, sortAgg_expr_0_0);

ev.value如下:range_mutableStateArray_0[0].getRow()


val doConsume = ctx.freshName(“doConsume”)

构建函数的名字,这里为sortAgg_doConsume_0


val doConsumeFuncName =

构造函数调用,其中主要调用的是parent.doConsume(ctx, inputVarsInFunc, rowVar)方法,

注意:这里的rowVar在SortAggregateExec中不会被用到,但是在WholeStageCodeGenExec中会被用到


最后的s"""${ctx.registerComment(s"CONSUME: ${parent.simpleString(conf.maxToStringFields)}")}$evaluated 则是组装代码

相关文章
|
存储 人工智能 大数据
The Past, Present and Future of Apache Flink
本文整理自阿里云开源大数据负责人王峰(莫问)在 Flink Forward Asia 2024 上海站主论坛开场的分享,今年正值 Flink 开源项目诞生的第 10 周年,借此时机,王峰回顾了 Flink 在过去 10 年的发展历程以及 Flink社区当前最新的技术成果,最后展望下一个十年 Flink 路向何方。
793 33
The Past, Present and Future of Apache Flink
|
人工智能 大数据 Apache
Flink Forward Asia 2024 即将盛大开幕!
Flink Forward Asia 2024是由Apache官方授权的技术大会,聚焦流式湖仓、流批一体、AI大模型等热点方向,旨在分享Flink社区最新动态及实践经验,是Flink开发者和使用者不容错过的盛会。大会不仅探讨了Flink在实时大数据分析中的应用,还深入讨论了Data+AI领域的新成果,如基于Flink和Elasticsearch的企业级高级RAG架构设计,展示了Flink在多模态数据处理、实时数据向量化等方面的强大能力。
|
JSON Java API
在 Java 中解析 JSON ArrayList 的详细指南
【8月更文挑战第23天】
519 1
|
SQL 存储 HIVE
Hive中的分桶表是什么?请解释其作用和使用场景。
Hive中的分桶表是什么?请解释其作用和使用场景。
553 0
|
Java Python
【已解决】RuntimeError Java gateway process exited before sending its port number
【已解决】RuntimeError Java gateway process exited before sending its port number
606 0
|
存储 数据处理
计算机的发展史与计算机硬件组成(上)
计算机的发展史与计算机硬件组成
553 0
|
SQL 分布式计算 大数据
大数据SQL数据倾斜与数据膨胀的优化与经验总结
目前市面上大数据查询分析引擎层出不穷,但在业务使用过程中,大多含有性能瓶颈的SQL,主要集中在数据倾斜与数据膨胀问题中。本文结合业界对大数据SQL的使用与优化,尝试给出相对系统性的解决方案。
14075 5
|
SQL 存储 分布式数据库
hive中的索引
hive中的索引
631 0
|
Kubernetes 调度 Docker
k8s教程(pod篇)-优先级调度
k8s教程(pod篇)-优先级调度
415 0
|
存储 缓存 安全
Shiro-全面详解(学习总结---从入门到深化)(上)
Shiro是apache旗下的一个开源安全框架,它可以帮助我们完成身 份认证,授权、加密、会话管理等功能。
581 0
Shiro-全面详解(学习总结---从入门到深化)(上)