Spark 聚合函数在范围内跳跃或跨步

Spark aggregate function over range with skip or stride

我正在尝试在具有范围的 window 函数上计算 sum() 之类的聚合函数,但我只想包含第 N 行。并且它跳过的内容应该相对于 window 的前面(始终包括 window 中的第一行)

//val df = some Dataframe {symbol,datetime,metric}
val baseWin = Window.partitionBy("symbol").orderBy("datetime")

//This is a plain sum over the window
val plain = sum(col("metric")).over(baseWin.rowsBetween(-12,0))

//This is ALMOST what I want (every 3rd) BUT isn't relative to the window
val almost = sum(when(col("datetime")/lit(DAY) %3 === 0, col("metric")).over(baseWin.rowsBetween(-12,0))

您可以使用 lag。范围定义为:

scala> (0 to 12 by 3)
res1: scala.collection.immutable.Range = Range(0, 3, 6, 9, 12)

您可以对所有滞后进行求和(默认为 0):

val almost = (0 to 12 by 3).map(lag($"metric", _, 0).over(baseWindow)).reduce(_ + _)

示例:

val df = spark.range(24).toDF("metric").withColumn("group", $"metric" > 12)

val baseWindow = Window.partitionBy("group").orderBy("metric")

df.withColumn("almost", almost).show
// +------+-----+------+
// |metric|group|almost|
// +------+-----+------+
// |    13| true|    13| 13
// |    14| true|    14| 14
// |    15| true|    15| 15 
// |    16| true|    29| 16 + 13
// |    17| true|    31| 17 + 14
// |    18| true|    33| 18 + 14
// |    19| true|    48| 19 + 16 + 13
// |    20| true|    51| 20 + 17 + 14
// |    21| true|    54| 21 + 18 + 15
// |    22| true|    70| 22 + 19 + 16 + 13
// |    23| true|    74| 23 + 20 + 17 + 14
// |     0|false|     0| ...
// |     1|false|     1| 1
// |     2|false|     2| 2
// |     3|false|     3| 3
// |     4|false|     5| 4 + 1
// |     5|false|     7| 5 + 2
// |     6|false|     9| 6 + 3
// |     7|false|    12| 7 + 4 + 1
// |     8|false|    15| 8 + 5 + 2
// +------+-----+------+