在 spark window 函数中求和并根据条件重新启动

Sum values and restarting on conditions in spark window functions

如果 updateType 列的值是 'relative',我想把当前值累加到之前的和上;如果是 'absolute',我想从当前值重新开始求和。

这里我定义了我的数据框:

// Sample frame. The update-type column must hold plain Strings: Scala Symbol
// literals ('relative) have no Spark SQL encoder and would never match the
// String patterns ("relative" / "absolute") used by the UDF/UDAF below.
val df = sc.parallelize(Seq(
  (1, "2018-02-21", "relative", 3.00),
  (1, "2018-02-22", "relative", 4.00),
  (1, "2018-02-23", "absolute", 5.00),
  (1, "2018-02-24", "relative", 6.00),
  (1, "2018-02-26", "relative", 8.00)
)).toDF("id", "date", "updateType", "value")

我定义了一个 UDF 来判断什么时候累加、什么时候重置。我想按日期排序,在遇到 'relative' 时把当前值累加进去,在遇到 'absolute' 时直接取当前值重新开始。

// Row-local combiner: merges the previous row's value with the current one.
// NOTE(review): fed with lag(value, 1) this only ever sees the previous *input*
// value, never the accumulated sum, so runs of 3+ "relative" rows under-count —
// e.g. 2018-02-26 yields 6 + 8 = 14 instead of the expected 11 + 8 = 19.
// That is exactly why this UDF approach fails and a UDAF is needed.
val computeValue = udf((previous: java.math.BigDecimal, value: java.math.BigDecimal, updateType: String) => {
  updateType match {
    case "absolute" => value               // restart: take the current value as-is
    case "relative" => previous.add(value) // accumulate onto the previous value
    case _ => previous                     // unknown type: keep the previous value
  }
})
// Per-id window ordered by date; lag(value, 1, 0) hands the previous row's raw
// value (0 for the first row of each partition) to the UDF above.
val w = Window.partitionBy($"id").orderBy($"date")

val previousValue = lag($"value", 1, 0).over(w)

val result = df.select(
  $"id",
  $"date",
  computeValue(previousValue, $"value", $"updateType").alias("sumValue")
)

实际返回的结果是:

+---+----------+---------+
| id|      date| sumValue|
+---+----------+---------+
|  1|2018-02-21|3.000    |
|  1|2018-02-22|7.000    |
|  1|2018-02-23|5.00     |
|  1|2018-02-24|11.00    |
|  1|2018-02-26|14.00    |
+---+----------+---------+

我正在寻找:

+---+----------+---------+
| id|      date| sumValue|
+---+----------+---------+
|  1|2018-02-21|3.000    |
|  1|2018-02-22|7.000    |
|  1|2018-02-23|5.00     |
|  1|2018-02-24|11.00    |
|  1|2018-02-26|19.00    |
+---+----------+---------+

答案是使用UDAF(User Defined Aggregation Function)进行这种操作。

// Instantiate the UDAF and apply it over a per-id window ordered by date, so
// the running value is carried from row to row and reset on "absolute" rows.
val computeValue = new ComputeValue
val w = Window.partitionBy($"id").orderBy($"date")

val runningValue = computeValue($"value", $"updateType").over(w)

val result = df.select($"id", $"date", runningValue.alias("sumValue"))

其中 ComputeValue UDAF 是:

/**
 * UDAF computing a running sum that restarts whenever an "absolute" row is seen.
 *
 * Input rows: (value: Double, update_type: String); result: Double.
 * The original `merge` routed `buffer2` through the per-row path and read a
 * second String field that a buffer (schema: a single Double) does not have —
 * a guaranteed runtime failure if Spark ever merged partial buffers. The buffer
 * now also tracks whether a restart occurred, which makes `merge` well-defined.
 */
class ComputeValue extends UserDefinedAggregateFunction {

  // Each input row is value: Double - update_type: String.
  override def inputSchema: org.apache.spark.sql.types.StructType =
    StructType(
      StructField("value", DoubleType) ::
        StructField("update_type", StringType) :: Nil)

  // Internal state: the running value, plus a flag recording whether this
  // buffer has seen a restart ("absolute") — required for a correct merge.
  override def bufferSchema: StructType = StructType(
    StructField("value", DoubleType) ::
      StructField("sawRestart", org.apache.spark.sql.types.BooleanType) :: Nil
  )

  // Result type exposed to callers.
  override def dataType: DataType = DoubleType

  // The same ordered input always produces the same result.
  override def deterministic: Boolean = true

  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0.0
    buffer(1) = false
  }

  // Fold one input row into the buffer, in row order.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val updateType: String = input.getAs[String](1)
    val current: Double = input.getAs[Double](0)
    updateType match {
      case "relative" =>
        buffer(0) = buffer.getDouble(0) + current
      case _ =>
        // "absolute" (or any unknown type, matching the original fallback)
        // restarts the running value at the current row.
        buffer(0) = current
        buffer(1) = true
    }
  }

  // Combine two partial buffers: if the later chunk restarted, its value
  // stands alone; otherwise the two running sums are additive.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    if (buffer2.getBoolean(1)) {
      buffer1(0) = buffer2.getDouble(0)
      buffer1(1) = true
    } else {
      buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)
    }
  }

  // Final result: the running value.
  override def evaluate(buffer: Row): Any = {
    buffer.getDouble(0)
  }
}