使用 Spark DataFrame 获取组后所有组的 TopN

get TopN of all groups after group by using Spark DataFrame

我有一个 Spark SQL DataFrame:

user1 item1 rating1
user1 item2 rating2
user1 item3 rating3
user2 item1 rating4
...

如何使用 Scala 按用户分组,然后 return TopN 个项目来自每个组?

相似度代码使用 Python:

df.groupby("user").apply(the_func_get_TopN)

您可以使用rank window功能如下

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{rank, desc}

val n: Int = ???

// Window definition
val w = Window.partitionBy($"user").orderBy(desc("rating"))

// Filter
df.withColumn("rank", rank.over(w)).where($"rank" <= n)

如果您不在意平局,那么您可以将 rank 替换为 row_number