How to aggregate an array of struct in spark with group by

I am using Spark 2.1, and I have a DataFrame with this schema:

scala> df.printSchema

|-- id: integer (nullable = true)
|-- sum: integer (nullable = true)
|-- distribution: array (nullable = true)
|    |-- element: struct (containsNull = true)
|    |    |-- lower: integer (nullable = true)
|    |    |-- upper: integer (nullable = true)
|    |    |-- count: integer (nullable = true)

What I want is to aggregate by "id": sum the "sum" column and merge the distributions, adding the counts of buckets that share the same lower and upper bounds.

I cannot simply explode the DataFrame here, because that would duplicate rows and I could no longer sum the "sum" column. One possibility is to aggregate the sum and the distribution separately and then join the two results back on "id", but a user-defined function would be simpler.
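For reference, a minimal sketch of that explode/aggregate/join alternative could look like this (assuming the input DataFrame is named df, as above; the val names sums, buckets and joined are only illustrative):

import org.apache.spark.sql.functions._

// sum the "sum" column per id
val sums = df.groupBy("id").agg(sum("sum").as("sum"))

// explode the buckets, add up counts per (id, lower, upper), and rebuild the array of structs
val buckets = df
  .select(col("id"), explode(col("distribution")).as("d"))
  .select(col("id"), col("d.lower").as("lower"), col("d.upper").as("upper"), col("d.count").as("count"))
  .groupBy("id", "lower", "upper")
  .agg(sum("count").as("count"))
  .groupBy("id")
  .agg(collect_list(struct("lower", "upper", "count")).as("distribution"))

// join both aggregations back on "id"
val joined = sums.join(buckets, "id")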

As input, I have:

scala> df.show(false)

+---+---+------------------------------------------------------------+
|id |sum|distribution                                                |
+---+---+------------------------------------------------------------+
|1  |1  |[[0,1,2]]                                                   |
|1  |1  |[[1,2,5]]                                                   |
|1  |7  |[[0,1,1], [1,2,6]]                                          |
|1  |7  |[[0,1,5], [1,2,1], [2,3,1]]                                 |
|2  |1  |[[0,1,1]]                                                   |
|2  |2  |[[0,1,1], [1,2,1]]                                          |
|2  |1  |[[0,1,1]]                                                   |
|2  |1  |[[2,3,1]]                                                   |
|2  |1  |[[0,1,1]]                                                   |
|2  |4  |[[0,1,1], [1,2,1], [2,3,1], [3,4,1]]                        |
+---+---+------------------------------------------------------------+

Expected output:

+---+---+------------------------------------------------------------+
|id |sum|distribution                                                |
+---+---+------------------------------------------------------------+
|1  |16 |[[0,1,8], [1,2,12], [2,3,1]]                                |
|2  |10 |[[0,1,5], [1,2,2], [2,3,3], [3,4,1]]                        |
+---+---+------------------------------------------------------------+

You can use this UDF:

import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.{Column, DataFrame, Row, SparkSession}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{ArrayType, IntegerType, StructField, StructType}


// schema used as the return type of the UDF, so the output keeps the same array-of-struct format
val schema: ArrayType = ArrayType(StructType(Seq(
  StructField("lower", IntegerType, false),
  StructField("upper", IntegerType, false),
  StructField("count", IntegerType, false)
)))

// UDF that flattens the collected arrays of buckets, groups the buckets
// by (lower, upper) and sums their counts
val customAggregation = udf((xs: Seq[Seq[Row]]) =>
  xs.flatten
    .map(row => (
      row.getAs[Int]("lower"),
      row.getAs[Int]("upper"),
      row.getAs[Int]("count")
    ))
    .groupBy { case (lower, upper, _) => (lower, upper) }
    .mapValues(_.map(_._3).sum)
    .toSeq
    .map { case ((lower, upper), count) => (lower, upper, count) },
  schema
)
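To check what the UDF computes, here is the same merge logic on plain Scala tuples for the id = 1 group (a standalone sketch, not part of the Spark job):

// buckets collected for id = 1, written as (lower, upper, count) tuples
val group1 = Seq(
  Seq((0, 1, 2)),
  Seq((1, 2, 5)),
  Seq((0, 1, 1), (1, 2, 6)),
  Seq((0, 1, 5), (1, 2, 1), (2, 3, 1))
)

val merged = group1.flatten
  .groupBy { case (lower, upper, _) => (lower, upper) }
  .mapValues(_.map(_._3).sum)
  .toSeq
  .map { case ((lower, upper), count) => (lower, upper, count) }
  .sortBy(_._1)

// merged == Seq((0,1,8), (1,2,12), (2,3,1))

Note that groupBy on a Scala collection does not guarantee any ordering of the resulting buckets, so add a sort inside the UDF if the order of the distribution matters to you.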


// group by "id", sum the "sum" column, collect the distribution arrays
// and merge them with the UDF defined above
val dfOutput: DataFrame = df_input
  .groupBy("id")
  .agg(sum("sum"), collect_list("distribution"))
  .toDF("id", "sum", "distribution")
  .withColumn("distribution_agg", customAggregation(col("distribution")))

The result will be:

scala> dfOutput.select("id","sum","distribution_agg").show(false)
+---+---+------------------------------------------------------------+
|id |sum|distribution_agg                                            |
+---+---+------------------------------------------------------------+
|1  |16 |[[0,1,8], [1,2,12], [2,3,1]]                                |
|2  |10 |[[0,1,5], [1,2,2], [2,3,3], [3,4,1]]                        |
+---+---+------------------------------------------------------------+
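If you want the merged column to be called distribution, as in the expected output, a small follow-up select (not part of the answer above, just a cosmetic rename) does it:

val result = dfOutput.select(col("id"), col("sum"), col("distribution_agg").as("distribution"))
result.show(false)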