How to get other columns when using Spark DataFrame groupby?

When I use DataFrame groupBy like this:

df.groupBy(df("age")).agg(Map("id"->"count"))

I only get a DataFrame with the columns "age" and "count(id)", but df has many other columns, such as "name".

In short, I want to get the same result as this MySQL query:

"select name,age,count(id) from df group by age"

What should I do when using groupBy in Spark?

Long story short, in general you have to join the aggregated results back to the original table. Spark SQL follows the same pre-SQL:1999 convention as most of the major databases (PostgreSQL, Oracle, MS SQL Server), which does not allow additional columns in aggregation queries.

Since for an aggregation like count the extra values are not well defined, and behaviour tends to vary between the systems that do support this kind of query, you can simply include the additional columns using arbitrary aggregates such as first or last.

In some cases you can replace select / agg with window functions and a subsequent where, but depending on the context it can be quite expensive.
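
A minimal Scala sketch of these three options (my own illustration, assuming the question's df with name, age and id columns):

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

// Option 1: aggregate, then join the result back to the original table
val counts = df.groupBy("age").agg(count("id").as("count"))
val joined = df.join(counts, Seq("age"))

// Option 2: carry the extra column through with an arbitrary aggregate such as first
val withFirst = df.groupBy("age").agg(first("name").as("name"), count("id").as("count"))

// Option 3: a window function instead of groupBy/agg (can be expensive),
// optionally followed by a where/filter to keep only the rows you want
val byAge = Window.partitionBy("age")
val windowed = df.withColumn("count", count("id").over(byAge))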

You can do it like this:

示例数据:

name    age id
abc     24  1001
cde     24  1002
efg     22  1003
ghi     21  1004
ijk     20  1005
klm     19  1006
mno     18  1007
pqr     18  1008
rst     26  1009
tuv     27  1010
pqr     18  1012
rst     28  1013
tuv     29  1011

df.select("name","age","id").groupBy("name","age").count().show()

Output:

    +----+---+-----+
    |name|age|count|
    +----+---+-----+
    | efg| 22|    1|
    | tuv| 29|    1|
    | rst| 28|    1|
    | klm| 19|    1|
    | pqr| 18|    2|
    | cde| 24|    1|
    | tuv| 27|    1|
    | ijk| 20|    1|
    | abc| 24|    1|
    | mno| 18|    1|
    | ghi| 21|    1|
    | rst| 26|    1|
    +----+---+-----+

One way to get all the columns after doing a groupBy is to use a join:

feature_group = ['name', 'age']
data_counts = df.groupBy(feature_group).count().alias("counts")
data_joined = df.join(data_counts, feature_group)

data_joined will now contain all the columns, including the count values.

This solution might be helpful.

from pyspark.sql import SQLContext
from pyspark import SparkContext, SparkConf
from pyspark.sql import functions as F
from pyspark.sql import Window

# create the contexts the snippet relies on
sc = SparkContext.getOrCreate(SparkConf())
sqlContext = SQLContext(sc)

name_list = [(101, 'abc', 24), (102, 'cde', 24), (103, 'efg', 22), (104, 'ghi', 21),
             (105, 'ijk', 20), (106, 'klm', 19), (107, 'mno', 18), (108, 'pqr', 18),
             (109, 'rst', 26), (110, 'tuv', 27), (111, 'pqr', 18), (112, 'rst', 28), (113, 'tuv', 29)]

age_w = Window.partitionBy("age")
name_age_df = sqlContext.createDataFrame(name_list, ['id', 'name', 'age'])

name_age_count_df = name_age_df.withColumn("count", F.count("id").over(age_w)).orderBy("count")
name_age_count_df.show()

Output:

+---+----+---+-----+
| id|name|age|count|
+---+----+---+-----+
|109| rst| 26|    1|
|113| tuv| 29|    1|
|110| tuv| 27|    1|
|106| klm| 19|    1|
|103| efg| 22|    1|
|104| ghi| 21|    1|
|105| ijk| 20|    1|
|112| rst| 28|    1|
|101| abc| 24|    2|
|102| cde| 24|    2|
|107| mno| 18|    3|
|111| pqr| 18|    3|
|108| pqr| 18|    3|
+---+----+---+-----+

You need to remember that aggregate functions reduce rows, so you need to specify which row's values you want with a reducing function. If you want to keep all rows of a group (warning: this can cause exploded or skewed partitions), you can collect them as a list. You can then use a UDF (user defined function) to reduce them by your criteria, in my example money. Then expand the columns from the single reduced row with another UDF. For the purposes of this answer I assume you want to keep the name of the person who has the most money.

import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.StringType

import scala.collection.mutable


object TestJob3 {

  def main(args: Array[String]): Unit = {

    val sparkSession = SparkSession
      .builder()
      .appName(this.getClass.getName.replace("$", ""))
      .master("local")
      .getOrCreate()

    val sc = sparkSession.sparkContext

    import sparkSession.sqlContext.implicits._

    val rawDf = Seq(
      (1, "Moe", "Slap", 2.0, 18),
      (2, "Larry", "Spank", 3.0, 15),
      (3, "Curly", "Twist", 5.0, 15),
      (4, "Laurel", "Whimper", 3.0, 9),
      (5, "Hardy", "Laugh", 6.0, 18),
      (6, "Charley", "Ignore", 5.0, 5)
    ).toDF("id", "name", "requisite", "money", "age")

    rawDf.show(false)
    rawDf.printSchema

    val rawSchema = rawDf.schema

    // UDF that reduces the collected rows of a group to the single row with the most money
    val fUdf = udf(reduceByMoney, rawSchema)

    // UDF that extracts the name column from the reduced row
    val nameUdf = udf(extractName, StringType)

    val aggDf = rawDf
      .groupBy("age")
      .agg(
        count(struct("*")).as("count"),
        max(col("money")),
        collect_list(struct("*")).as("horizontal")
      )
      .withColumn("short", fUdf($"horizontal"))
      .withColumn("name", nameUdf($"short"))
      .drop("horizontal")

    aggDf.printSchema

    aggDf.show(false)
  }

  // Reduce a collected list of rows to the row with the largest money value
  def reduceByMoney = (x: Any) => {

    val d = x.asInstanceOf[mutable.WrappedArray[GenericRowWithSchema]]

    val red = d.reduce((r1, r2) => {

      val money1 = r1.getAs[Double]("money")
      val money2 = r2.getAs[Double]("money")

      val r3 = money1 match {
        case a if a >= money2 =>
          r1
        case _ =>
          r2
      }

      r3
    })

    red
  }

  // Pull the name field out of a reduced row
  def extractName = (x: Any) => {

    val d = x.asInstanceOf[GenericRowWithSchema]

    d.getAs[String]("name")
  }
}

Here is the output:

+---+-----+----------+----------------------------+-------+
|age|count|max(money)|short                       |name   |
+---+-----+----------+----------------------------+-------+
|5  |1    |5.0       |[6, Charley, Ignore, 5.0, 5]|Charley|
|15 |2    |5.0       |[3, Curly, Twist, 5.0, 15]  |Curly  |
|9  |1    |3.0       |[4, Laurel, Whimper, 3.0, 9]|Laurel |
|18 |2    |6.0       |[5, Hardy, Laugh, 6.0, 18]  |Hardy  |
+---+-----+----------+----------------------------+-------+

An aggregate function reduces the row values of the specified columns within a group. If you want to keep other row values, you need to implement reduction logic that specifies which row each value comes from, for example keeping all values of the first row with the maximum money. To this end you can use a UDAF (user defined aggregate function) to reduce the rows within a group.

import org.apache.spark.sql._
import org.apache.spark.sql.functions._


object AggregateKeepingRowJob {

  def main (args: Array[String]): Unit = {

    val sparkSession = SparkSession
      .builder()
      .appName(this.getClass.getName.replace("$", ""))
      .master("local")
      .getOrCreate()

    val sc = sparkSession.sparkContext
    sc.setLogLevel("ERROR")

    import sparkSession.sqlContext.implicits._

    val rawDf = Seq(
      (1L, "Moe",  "Slap",  2.0, 18),
      (2L, "Larry",  "Spank",  3.0, 15),
      (3L, "Curly",  "Twist", 5.0, 15),
      (4L, "Laurel", "Whimper", 3.0, 15),
      (5L, "Hardy", "Laugh", 6.0, 15),
      (6L, "Charley",  "Ignore",   5.0, 5)
    ).toDF("id", "name", "requisite", "money", "age")

    rawDf.show(false)
    rawDf.printSchema

    val keepRowWithMaxAmt = new KeepRowWithMaxAmt

    val aggDf = rawDf
      .groupBy("age")
      .agg(
        count("id"),
        max(col("money")),
        keepRowWithMaxAmt(
          col("id"),
          col("name"),
          col("requisite"),
          col("money"),
          col("age")).as("KeepRowWithMaxAmt")
      )
      )

    aggDf.printSchema
    aggDf.show(false)

  }


}

UDAF:

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

class KeepRowWithMaxAmt extends UserDefinedAggregateFunction {

  // This is the input fields for your aggregate function:
  // the full row (matching rawDf) that we may want to keep.
  override def inputSchema: org.apache.spark.sql.types.StructType =
    StructType(
      StructField("id", LongType) ::
      StructField("name", StringType) ::
      StructField("requisite", StringType) ::
      StructField("money", DoubleType) ::
      StructField("age", IntegerType) :: Nil
    )

  // This is the internal fields you keep for computing your aggregate:
  // the best row seen so far.
  override def bufferSchema: StructType = StructType(
    StructField("id", LongType) ::
    StructField("name", StringType) ::
    StructField("requisite", StringType) ::
    StructField("money", DoubleType) ::
    StructField("age", IntegerType) :: Nil
  )

  // This is the output type of your aggregation function.
  override def dataType: DataType =
    StructType(Array(
      StructField("id", LongType),
      StructField("name", StringType),
      StructField("requisite", StringType),
      StructField("money", DoubleType),
      StructField("age", IntegerType)
    ))

  override def deterministic: Boolean = true

  // This is the initial value for your buffer schema.
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = ""
    buffer(2) = ""
    buffer(3) = Double.MinValue
    buffer(4) = 0
  }

  // This is how to update your buffer schema given an input:
  // keep the input row if it has more money than the buffered one.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {

    val money = buffer.getAs[Double](3)
    val candidateMoney = input.getAs[Double](3)

    money match {
      case a if a < candidateMoney =>
        buffer(0) = input.getAs[Long](0)
        buffer(1) = input.getAs[String](1)
        buffer(2) = input.getAs[String](2)
        buffer(3) = input.getAs[Double](3)
        buffer(4) = input.getAs[Int](4)
      case _ =>
    }
  }

  // This is how to merge two objects with the bufferSchema type:
  // keep whichever buffer holds the row with the larger money value.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {

    if (buffer1.getAs[Double](3) < buffer2.getAs[Double](3)) {
      buffer1(0) = buffer2.getAs[Long](0)
      buffer1(1) = buffer2.getAs[String](1)
      buffer1(2) = buffer2.getAs[String](2)
      buffer1(3) = buffer2.getAs[Double](3)
      buffer1(4) = buffer2.getAs[Int](4)
    }
  }

  // This is where you output the final value, given the final value of your bufferSchema.
  override def evaluate(buffer: Row): Any = {
    buffer
  }
}

Here is an example I came across at spark-workshop:

val populationDF = spark.read
                .option("inferSchema", "true")
                .option("header", "true")
                .format("csv").load("file:///databricks/driver/population.csv")
                .select('name, regexp_replace(col("population"), "\\s", "").cast("integer").as("population"))

val maxPopulationDF = populationDF.agg(max('population).as("populationmax"))

To get other columns, I do a simple join between the original DF and the aggregated one

populationDF.join(maxPopulationDF,populationDF.col("population") === maxPopulationDF.col("populationmax")).select('name, 'populationmax).show()

This pyspark code selects the B value of the max([A, B]) combination of each A group (if several maxima exist in one group, a random one is picked).

In your case, A would be age and B any of the columns you did not group by but nevertheless want to select.

from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.getOrCreate()

df = spark.createDataFrame([
    [1, 1, 0.2],
    [1, 1, 0.9],
    [1, 2, 0.6],
    [1, 2, 0.5],
    [1, 2, 0.6],
    [2, 1, 0.2],
    [2, 2, 0.1],
], ["group", "A", "B"])

out = (
    df
    .withColumn("AB", F.struct("A", "B"))
    .groupby("group")
    # F.max(AB) selects AB-combinations with max `A`. If more
    # than one combination remains the one with max `B` is selected. If
    # after this identical combinations remain, a single one of them is picked
    # randomly.
    .agg(F.max("AB").alias("max_AB"))
    .select("group", F.expr("max_AB.B"))
)
out.show()

Output:

+-----+---+
|group|  B|
+-----+---+
|    1|0.6|
|    2|0.1|
+-----+---+
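
For reference, here is a rough Scala sketch of the same struct-max trick applied to the question's df (my own addition; it assumes the name, age and id columns from the question):

import org.apache.spark.sql.functions._

val picked = df
  .withColumn("pair", struct(col("id"), col("name")))   // put the column to maximise first in the struct
  .groupBy("age")
  .agg(max("pair").as("max_pair"))                       // max over a struct orders by its first field, then the next
  .select(col("age"), col("max_pair.id").as("id"), col("max_pair.name").as("name"))

picked.show()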

#solved #working solution

This solution was put together with the help of @Azmisov's comments in this thread; the code example is taken from https://sparkbyexamples.com/spark/using-groupby-on-dataframe/

Problem: in Spark Scala, when using groupby and max on a dataframe, the returned dataframe contains only the columns used in the groupby and max. How do you get all the other columns as well? In other words, how do you get the non-groupby columns?

Solution: please go through the full example below to get all the columns along with groupby and max.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._ //{col, lit, when, to_timestamp}
import org.apache.spark.sql.types._
import org.apache.spark.sql.Column

val spark = SparkSession
  .builder()
  .appName("app-name")
  .master("local[*]")
  .getOrCreate()

spark.sparkContext.setLogLevel("ERROR")
import spark.implicits._

val simpleData = Seq(("James","Sales","NY",90000,34,10000),
  ("Michael","Sales","NY",86000,56,20000),
  ("Robert","Sales","CA",81000,30,23000),
  ("Maria","Finance","CA",90000,24,23000),
  ("Raman","Finance","CA",99000,40,24000),
  ("Scott","Finance","NY",83000,36,19000),
  ("Jen","Finance","NY",79000,53,15000),
  ("Jeff","Marketing","CA",80000,25,18000),
  ("Kumar","Marketing","NY",91000,50,21000)
)
val df = simpleData.toDF("employee_name","department","state","salary","age","bonus")
df.show()

Generating df gives the following output:

Output:
+-------------+----------+-----+------+---+-----+
|employee_name|department|state|salary|age|bonus|
+-------------+----------+-----+------+---+-----+
|        James|     Sales|   NY| 90000| 34|10000|
|      Michael|     Sales|   NY| 86000| 56|20000|
|       Robert|     Sales|   CA| 81000| 30|23000|
|        Maria|   Finance|   CA| 90000| 24|23000|
|        Raman|   Finance|   CA| 99000| 40|24000|
|        Scott|   Finance|   NY| 83000| 36|19000|
|          Jen|   Finance|   NY| 79000| 53|15000|
|         Jeff| Marketing|   CA| 80000| 25|18000|
|        Kumar| Marketing|   NY| 91000| 50|21000|
+-------------+----------+-----+------+---+-----+

The code below gives output where the column names are not ideal, but it can still be used:

val dfwithmax = df.groupBy("department").agg(max("salary"), first("employee_name"), first("state"), first("age"), first("bonus"))
dfwithmax.show()

+----------+-----------+---------------------------+-------------------+-----------------+-------------------+
|department|max(salary)|first(employee_name, false)|first(state, false)|first(age, false)|first(bonus, false)|
+----------+-----------+---------------------------+-------------------+-----------------+-------------------+
|     Sales|      90000|                      James|                 NY|               34|              10000|
|   Finance|      99000|                      Maria|                 CA|               24|              23000|
| Marketing|      91000|                       Jeff|                 CA|               25|              18000|
+----------+-----------+---------------------------+-------------------+-----------------+-------------------+

To get proper column names, you can alias them as shown below:

val dfwithmax1 = df.groupBy("department").agg(max("salary") as "salary", first("employee_name") as "employee_name", first("state") as "state", first("age") as "age",first("bonus") as "bonus")
dfwithmax1.show()

Output:
+----------+------+-------------+-----+---+-----+
|department|salary|employee_name|state|age|bonus|
+----------+------+-------------+-----+---+-----+
|     Sales| 90000|        James|   NY| 34|10000|
|   Finance| 99000|        Maria|   CA| 24|23000|
| Marketing| 91000|         Jeff|   CA| 25|18000|
+----------+------+-------------+-----+---+-----+

If you also want to change the order of the dataframe columns, it can be done as follows:

val reOrderedColumnName : Array[String] = Array("employee_name", "department", "state", "salary", "age", "bonus")
val orderedDf = dfwithmax1.select(reOrderedColumnName.head, reOrderedColumnName.tail: _*)
orderedDf.show()

The full code put together:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.Column

object test {
    def main(args: Array[String]): Unit = {
        /** spark session object */
        val spark = SparkSession.builder().appName("app-name").master("local[*]")
                  .getOrCreate()
                  
        spark.sparkContext.setLogLevel("ERROR")
        import spark.implicits._
        
        val simpleData = Seq(("James","Sales","NY",90000,34,10000),
          ("Michael","Sales","NY",86000,56,20000),
          ("Robert","Sales","CA",81000,30,23000),
          ("Maria","Finance","CA",90000,24,23000),
          ("Raman","Finance","CA",99000,40,24000),
          ("Scott","Finance","NY",83000,36,19000),
          ("Jen","Finance","NY",79000,53,15000),
          ("Jeff","Marketing","CA",80000,25,18000),
          ("Kumar","Marketing","NY",91000,50,21000)
        )
        val df = simpleData.toDF("employee_name","department","state","salary","age","bonus")
        df.show()

        val dfwithmax = df.groupBy("department").agg(max("salary"), first("employee_name"), first("state"), first("age"), first("bonus"))
        dfwithmax.show()
        
        val dfwithmax1 = df.groupBy("department").agg(max("salary") as "salary", first("employee_name") as "employee_name", first("state") as "state", first("age") as "age",first("bonus") as "bonus")
        dfwithmax1.show()
        
        val reOrderedColumnName : Array[String] = Array("employee_name", "department", "state", "salary", "age", "bonus")
        val orderedDf = dfwithmax1.select(reOrderedColumnName.head, reOrderedColumnName.tail: _*)
        orderedDf.show()
    }
}

Full output:
+-------------+----------+-----+------+---+-----+
|employee_name|department|state|salary|age|bonus|
+-------------+----------+-----+------+---+-----+
|        James|     Sales|   NY| 90000| 34|10000|
|      Michael|     Sales|   NY| 86000| 56|20000|
|       Robert|     Sales|   CA| 81000| 30|23000|
|        Maria|   Finance|   CA| 90000| 24|23000|
|        Raman|   Finance|   CA| 99000| 40|24000|
|        Scott|   Finance|   NY| 83000| 36|19000|
|          Jen|   Finance|   NY| 79000| 53|15000|
|         Jeff| Marketing|   CA| 80000| 25|18000|
|        Kumar| Marketing|   NY| 91000| 50|21000|
+-------------+----------+-----+------+---+-----+

+----------+-----------+---------------------------+------------------------+-------------------+-----------------+-------------------+
|department|max(salary)|first(employee_name, false)|first(department, false)|first(state, false)|first(age, false)|first(bonus, false)|
+----------+-----------+---------------------------+------------------------+-------------------+-----------------+-------------------+
|     Sales|      90000|                      James|                   Sales|                 NY|               34|              10000|
|   Finance|      99000|                      Maria|                 Finance|                 CA|               24|              23000|
| Marketing|      91000|                       Jeff|               Marketing|                 CA|               25|              18000|
+----------+-----------+---------------------------+------------------------+-------------------+-----------------+-------------------+

+----------+------+-------------+----------+-----+---+-----+
|department|salary|employee_name|department|state|age|bonus|
+----------+------+-------------+----------+-----+---+-----+
|     Sales| 90000|        James|     Sales|   NY| 34|10000|
|   Finance| 99000|        Maria|   Finance|   CA| 24|23000|
| Marketing| 91000|         Jeff| Marketing|   CA| 25|18000|
+----------+------+-------------+----------+-----+---+-----+

Exception:

Exception in thread "main" org.apache.spark.sql.AnalysisException: Reference 'department' is ambiguous, could be: department, department.;

It means you have the department column twice: it is already used in the groupBy (or max), and you also refer to it again as "department" via first("department") as "department".

For example (check the snippet below):

val dfwithmax1 = df.groupBy("department").agg(max("salary") as "salary", first("employee_name") as "employee_name", first("department") as "department", first("state") as "state", first("age") as "age",first("bonus") as "bonus")
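
A possible way around this ambiguity (my own sketch, not part of the original answer) is to give the aggregated copy a distinct alias, so that only one column is literally named "department":

val dfwithmaxDistinct = df.groupBy("department")
  .agg(
    max("salary") as "salary",
    first("employee_name") as "employee_name",
    first("department") as "department_first",  // hypothetical distinct alias avoids the ambiguous reference
    first("state") as "state",
    first("age") as "age",
    first("bonus") as "bonus"
  )
dfwithmaxDistinct.select("department", "salary", "employee_name").show()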

Thanks! Please upvote if you found this helpful.