Is there a Scala Spark function to group by, then filter, and then aggregate?

I have a DataFrame with a state column and a salary column. I need to group by state and find how many entries fall into each salary range (there are 3 salary ranges in total), build a DataFrame from that, and sort the result by state name. Is there any function in Spark to achieve this?

Sample input 

State  salary
------ ------
NY      6
WI      15
NY      11
WI      2
MI      20
NY      15 
 
Result expected is

State    group1   group2  group3
 MI         0       0       1  
 NY         0       1       2
 WI         1       0       1

where group1 is salaries in (0, 5], group2 is salaries in (5, 10], and group3 is salaries in (10, 20].

Basically, I'm looking for something like the following in Scala Spark (pseudocode):

df.groupBy('STATE').agg(count('*') as group1).where('SALARY' >0 and 'SALARY' <=5)
.agg(count('*') as group2).where('SALARY' >5 and 'SALARY' <=10)
.agg(count('*') as group3).where('SALARY' >10 and 'SALARY' <=20)

Solution update:

Solution 1: I was able to solve it with the approach below, but I'm not sure whether there is a simpler or more efficient way. Any pointers? dfWithoutSchema is the input DataFrame.

val newDf = dfWithoutSchema
  .withColumn("set1", when($"salary" > 0 and $"salary" <= 5, 1).otherwise(0))
  .withColumn("set2", when($"salary" > 5 and $"salary" <= 10, 1).otherwise(0))
  .withColumn("set3", when($"salary" > 10 and $"salary" <= 20, 1).otherwise(0))
val fdf = newDf.groupBy("state")
  .agg(sum("set1") as "group1", sum("set2") as "group2", sum("set3") as "group3")
  .sort("state")
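
For completeness, a minimal sketch (assuming a local SparkSession named spark) of how the dfWithoutSchema input used above can be built from the sample data in the question:

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

// Sample input from the question; column names match the snippet above
val dfWithoutSchema = Seq(
  ("NY", 6), ("WI", 15), ("NY", 11),
  ("WI", 2), ("MI", 20), ("NY", 15)
).toDF("state", "salary")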

Solution 2:

val agg_df = df.groupBy("State")
    .agg(
        count(when($"Salary" > 0 && $"Salary" <= 5, $"Salary")).as("group_1"),
        count(when($"Salary" > 5 && $"Salary" <= 10, $"Salary")).as("group_2"),
        count(when($"Salary" > 10 && $"Salary" <= 20, $"Salary")).as("group_3")
    )
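
Note that, unlike Solution 1, this version does not order the output; if the result also needs to be sorted by state name as the question asks, a sort on the aggregated frame should be enough (a small sketch, assuming the agg_df above):

// Order the aggregated result by state name
val sortedDf = agg_df.sort("State")
sortedDf.show(false)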

You can specify the condition on which to count/sum.

Example:

from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()
data = [
    {"State": "NY", "Salary": 6},
    {"State": "WI", "Salary": 15},
    {"State": "NY", "Salary": 11},
    {"State": "WI", "Salary": 2},
    {"State": "MI", "Salary": 20},
    {"State": "NY", "Salary": 15},
]
df = spark.createDataFrame(data=data)
# conditional count: sum 1 for rows matching cond, 0 otherwise
cnt_cond = lambda cond: F.sum(F.when(cond, 1).otherwise(0))
df = df.groupBy("State").agg(
    cnt_cond((F.col("Salary") > 0) & (F.col("Salary") <= 5)).alias("group_1"),
    cnt_cond((F.col("Salary") > 5) & (F.col("Salary") <= 10)).alias("group_2"),
    cnt_cond((F.col("Salary") > 10) & (F.col("Salary") <= 20)).alias("group_3"),
)

Here sum works the same as count, because the expression checks the condition and returns 1 if it is met, otherwise 0.

Using Scala:

val agg_df = df.groupBy("State")
    .agg(
        count(when($"Salary" > 0 && $"Salary" <= 5, $"Salary")).as("group_1"),
        count(when($"Salary" > 5 && $"Salary" <= 10, $"Salary")).as("group_2"),
        count(when($"Salary" > 10 && $"Salary" <= 20, $"Salary")).as("group_3")
    )

Result:

+-----+-------+-------+-------+                                                 
|State|group_1|group_2|group_3|
+-----+-------+-------+-------+
|NY   |0      |1      |2      |
|WI   |1      |0      |1      |
|MI   |0      |0      |1      |
+-----+-------+-------+-------+
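
As a side note on why count(when(...)) and sum(when(...).otherwise(0)) give the same numbers: when without an otherwise produces null for non-matching rows, and count ignores nulls. A small sketch of the two equivalent aggregate expressions in Scala (assuming spark.implicits._ is in scope):

import org.apache.spark.sql.functions.{count, sum, when}

// count skips the nulls that when(...) produces when no otherwise(...) is given
val byCount = count(when($"Salary" > 0 && $"Salary" <= 5, $"Salary"))
// sum adds up explicit 1/0 flags instead
val bySum   = sum(when($"Salary" > 0 && $"Salary" <= 5, 1).otherwise(0))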

You can use an expression built from sum and a case when:

data = [
    ('NY', 6),
    ('WI', 15),
    ('NY', 11),
    ('WI', 2),
    ('MI', 20),
    ('NY', 15)
]
df = spark.createDataFrame(data, ['State', 'salary'])
df = df.groupBy('State').agg(F.expr('sum(case when salary>0 and salary<=5 then 1 else 0 end)').alias('group1'),
                             F.expr('sum(case when salary>5 and salary<=10 then 1 else 0 end)').alias('group2'),
                             F.expr('sum(case when salary>10 and salary<=20 then 1 else 0 end)').alias('group3'))
df.show(truncate=False)
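
Since the question asks for Scala, the same expr / case when approach should carry over directly (a sketch, assuming a DataFrame df with State and salary columns as above):

import org.apache.spark.sql.functions.expr

val aggDf = df.groupBy("State").agg(
  expr("sum(case when salary > 0  and salary <= 5  then 1 else 0 end)").alias("group1"),
  expr("sum(case when salary > 5  and salary <= 10 then 1 else 0 end)").alias("group2"),
  expr("sum(case when salary > 10 and salary <= 20 then 1 else 0 end)").alias("group3")
)
aggDf.show(truncate = false)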