Rewrite LogicalPlan to push down UDF from aggregate
I defined a UDF named "inc" that adds one to its input. Here is the code for my UDF:
spark.udf.register("inc", (x: Long) => x + 1)
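(For reference, the same function could also be wrapped with the udf() helper from org.apache.spark.sql.functions instead of being registered for SQL; incUdf below is a name introduced only for this sketch.)

import org.apache.spark.sql.functions.udf
val incUdf = udf((x: Long) => x + 1)  // column-level UDF for the DataFrame API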
This is my test SQL:
val df = spark.sql("select sum(inc(vals)) from data")
df.explain(true)
df.show()
And this is the optimized logical plan for that SQL:
== Optimized Logical Plan ==
Aggregate [sum(inc(vals#4L)) AS sum(inc(vals))#7L]
+- LocalRelation [vals#4L]
I want to rewrite the plan and pull "inc" out of "sum", the way it is done for Python UDFs. So this is the optimized plan I want:
Aggregate [sum(inc_val#6L) AS sum(inc(vals))#7L]
+- Project [inc(vals#4L) AS inc_val#6L]
+- LocalRelation [vals#4L]
I found that the source file "ExtractPythonUDFs.scala" provides similar functionality for PythonUDF, but it inserts a new node named "ArrowEvalPython". This is the logical plan for a Python UDF:
== Optimized Logical Plan ==
Aggregate [sum(pythonUDF0#7L) AS sum(inc(vals))#4L]
+- Project [pythonUDF0#7L]
+- ArrowEvalPython [inc(vals#0L)], [pythonUDF0#7L], 200
+- Repartition 10, true
+- RelationV2[vals#0L] parquet file:/tmp/vals.parquet
All I want to insert is a plain Project node; I don't want to define a new node type. Here is the test code for my project:
import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression, ScalaUDF}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule

object RewritePlanTest {

  case class UdfRule(spark: SparkSession) extends Rule[LogicalPlan] {

    // Recursively collect every ScalaUDF in an expression tree.
    def collectUDFs(e: Expression): Seq[Expression] = e match {
      case udf: ScalaUDF => Seq(udf)
      case _ => e.children.flatMap(collectUDFs)
    }

    override def apply(plan: LogicalPlan): LogicalPlan = plan match {
      case agg @ Aggregate(g, a, _) if g.isEmpty && a.length == 1 =>
        val udfs = agg.expressions.flatMap(collectUDFs)
        println("================")
        udfs.foreach(println)
        // ScalaUDF extends Expression but not NamedExpression, so this prints false.
        val test = udfs(0).isInstanceOf[NamedExpression]
        println(s"cast ScalaUDF to NamedExpression = ${test}")
        println("================")
        agg
      case _ => plan
    }
  }

  def main(args: Array[String]): Unit = {
    Logger.getLogger("org").setLevel(Level.WARN)
    val spark = SparkSession
      .builder()
      .master("local[*]")
      .appName("Rewrite plan test")
      .withExtensions(e => e.injectOptimizerRule(UdfRule))
      .getOrCreate()

    val input = Seq(100L, 200L, 300L)
    import spark.implicits._
    input.toDF("vals").createOrReplaceTempView("data")
    spark.udf.register("inc", (x: Long) => x + 1)

    val df = spark.sql("select sum(inc(vals)) from data")
    df.explain(true)
    df.show()
    spark.stop()
  }
}
I extract the ScalaUDF expressions from the Aggregate node, but the argument the Project node requires is a Seq[NamedExpression]:
case class Project(projectList: Seq[NamedExpression], child: LogicalPlan)
A ScalaUDF cannot be cast to a NamedExpression, so I don't know how to build the Project node.
Can anyone give me some advice?
Thanks.
Well, I finally found a way to answer my own question. Although a ScalaUDF cannot be cast to a NamedExpression, an Alias can. So I create an Alias from the ScalaUDF, then construct the Project with it.
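The key step in isolation, as a minimal sketch (udfExpr stands for a collected ScalaUDF and child for the Aggregate's child; both names are placeholders introduced here):

val named = Alias(udfExpr, "udf0")()   // Alias is a NamedExpression, legal in a projectList
val proj = Project(Seq(named), child)
// proj.output.head is the Attribute that will replace the UDF inside sum(...)

The full rule: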
import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Expression, ScalaUDF}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule

import scala.collection.mutable

object RewritePlanTest {

  case class UdfRule(spark: SparkSession) extends Rule[LogicalPlan] {

    // Recursively collect every ScalaUDF in an expression tree.
    def collectUDFs(e: Expression): Seq[Expression] = e match {
      case udf: ScalaUDF => Seq(udf)
      case _ => e.children.flatMap(collectUDFs)
    }

    override def apply(plan: LogicalPlan): LogicalPlan = plan match {
      case agg @ Aggregate(g, a, c) if g.isEmpty && a.length == 1 =>
        val udfs = agg.expressions.flatMap(collectUDFs)
        if (udfs.isEmpty) {
          agg
        } else {
          // Wrap each ScalaUDF in an Alias (a NamedExpression) so it can
          // serve as a Project's projectList.
          val aliasedUdfs = for (i <- udfs.indices) yield Alias(udfs(i), s"udf$i")()
          val proj = Project(aliasedUdfs, c)
          // Map each original UDF expression to the attribute the Project produces.
          val udfToAttr = mutable.HashMap[Expression, Attribute]()
          udfToAttr ++= udfs.zip(proj.output)
          // Re-parent the Aggregate on the Project and replace every UDF
          // occurrence with the corresponding attribute reference.
          val newAgg = agg.withNewChildren(Seq(proj)).transformExpressionsUp {
            case udf: ScalaUDF if udfToAttr.contains(udf) => udfToAttr(udf)
          }
          println("====== new agg ======")
          println(newAgg)
          newAgg
        }
      case _ => plan
    }
  }

  def main(args: Array[String]): Unit = {
    Logger.getLogger("org").setLevel(Level.WARN)
    val spark = SparkSession
      .builder()
      .master("local[*]")
      .appName("Rewrite plan test")
      .withExtensions(e => e.injectOptimizerRule(UdfRule))
      .getOrCreate()

    val input = Seq(100L, 200L, 300L)
    import spark.implicits._
    input.toDF("vals").createOrReplaceTempView("data")
    spark.udf.register("inc", (x: Long) => x + 1)

    val df = spark.sql("select sum(inc(vals)) from data where vals > 100")
    df.explain(true)
    df.show()
    spark.stop()
  }
}
This code outputs the LogicalPlan I want:
====== new agg ======
Aggregate [sum(udf0#9L) AS sum(inc(vals))#7L]
+- Project [inc(vals#4L) AS udf0#9L]
+- LocalRelation [vals#4L]
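A note on why the replacement works: transformExpressionsUp visits the very same ScalaUDF instances that were collected from agg.expressions, so the HashMap lookup matches and each UDF occurrence inside sum(...) is swapped for the Project's output attribute. To check the rewrite programmatically rather than by reading the explain output, one could inspect the optimized plan directly (a small sketch, not part of the original test):

val optimized = df.queryExecution.optimizedPlan
// After the rule fires, a Project should sit between the Aggregate and the relation.
assert(optimized.collect { case p: Project => p }.nonEmpty)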