Summing values and restarting on a condition in Spark window functions
If a column is 'relative' I want to sum the values, and if it is 'absolute' I want to restart the sum.
Here is how I define my dataframe:
import spark.implicits._  // needed for toDF and the $"..." column syntax

val df = sc.parallelize(Seq(
  (1, "2018-02-21", "relative", 3.00),
  (1, "2018-02-22", "relative", 4.00),
  (1, "2018-02-23", "absolute", 5.00),
  (1, "2018-02-24", "relative", 6.00),
  (1, "2018-02-26", "relative", 8.00)
)).toDF("id", "date", "updateType", "value")
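For reference, df.show() on this data should print the following (value is a plain Double column here):

+---+----------+----------+-----+
| id|      date|updateType|value|
+---+----------+----------+-----+
|  1|2018-02-21|  relative|  3.0|
|  1|2018-02-22|  relative|  4.0|
|  1|2018-02-23|  absolute|  5.0|
|  1|2018-02-24|  relative|  6.0|
|  1|2018-02-26|  relative|  8.0|
+---+----------+----------+-----+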
I defined a UDF to decide when to sum and when not to. I want to order by date and then either add the value to the previous one, or restart from the 'absolute' value when required:
import org.apache.spark.sql.functions.udf

val computeValue = udf((previous: java.math.BigDecimal, value: java.math.BigDecimal, updateType: String) => {
  updateType match {
    case "absolute" => value
    case "relative" => previous.add(value)
    case _          => previous
  }
})
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.lag

val w = Window
  .partitionBy($"id")
  .orderBy($"date")

val result = df.select(
  $"id",
  $"date",
  computeValue(
    lag($"value", 1, 0).over(w),
    $"value",
    $"updateType"
  ).alias("sumValue")
)
What it actually returns (note the last row: lag only sees the previous raw value, 6.00, not the previously computed sumValue, 11.00, so it yields 6.00 + 8.00 = 14.00):
+---+----------+--------+
| id|      date|sumValue|
+---+----------+--------+
|  1|2018-02-21|   3.000|
|  1|2018-02-22|   7.000|
|  1|2018-02-23|    5.00|
|  1|2018-02-24|   11.00|
|  1|2018-02-26|   14.00|
+---+----------+--------+
What I am looking for:
+---+----------+--------+
| id|      date|sumValue|
+---+----------+--------+
|  1|2018-02-21|   3.000|
|  1|2018-02-22|   7.000|
|  1|2018-02-23|    5.00|
|  1|2018-02-24|   11.00|
|  1|2018-02-26|   19.00|
+---+----------+--------+
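Row by row, the expected running total is: 3.00; then 3.00 + 4.00 = 7.00; then a restart at the absolute 5.00; then 5.00 + 6.00 = 11.00; and finally 11.00 + 8.00 = 19.00 on 2018-02-26.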
The answer is to use a UDAF (User Defined Aggregate Function) for this kind of operation:
// Init aggregation function to compute values
val computeValue = new ComputeValue

val w = Window
  .partitionBy($"id")
  .orderBy($"date")

val result = df.select(
  $"id",
  $"date",
  computeValue(
    $"value",
    $"updateType"
  ).over(w).alias("sumValue")
)
where the ComputeValue UDAF is:
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

class ComputeValue extends UserDefinedAggregateFunction {

  // Each input row carries value: Double and update_type: String
  override def inputSchema: StructType =
    StructType(
      StructField("value", DoubleType) ::
      StructField("update_type", StringType) :: Nil)

  // A single buffer column where I keep the running total
  override def bufferSchema: StructType = StructType(
    StructField("value", DoubleType) :: Nil
  )

  override def dataType: DataType = DoubleType

  override def deterministic: Boolean = true

  override def initialize(buffer: MutableAggregationBuffer): Unit =
    buffer(0) = 0.0

  // Update the running total with each input row, in window order.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = computeValue(buffer, input)
  }

  // Merge two partial buffers. With an ordered window frame Spark feeds the
  // rows through update one by one, so merge should not be exercised here.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)
  }

  // Return the final value of the buffer.
  override def evaluate(buffer: Row): Any = {
    buffer.getDouble(0)
  }

  private def computeValue(buffer: MutableAggregationBuffer, row: Row): Double = {
    val updateType: String = row.getAs[String](1)
    val prev: Double = buffer.getDouble(0)
    val current: Double = row.getAs[Double](0)
    updateType match {
      case "relative" => prev + current  // keep accumulating
      case "absolute" => current         // restart the sum from here
      case _          => current
    }
  }
}
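With this UDAF in place, result.show() should now produce the expected 19.00 on 2018-02-26.

As a side note, the same restart-on-absolute semantics can also be expressed with built-in window functions alone, without a UDAF. A minimal sketch, assuming the same df as above (the grp column and the byDate / byGroup / result2 names are just illustrative):

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{sum, when}

// Every 'absolute' row opens a new group: running count of absolutes so far.
val byDate = Window.partitionBy($"id").orderBy($"date")
val grouped = df.withColumn(
  "grp",
  sum(when($"updateType" === "absolute", 1).otherwise(0)).over(byDate)
)

// Within a group, a plain running sum is the desired value: the group starts
// at its 'absolute' base (or at the leading 'relative' rows for group 0) and
// each 'relative' row adds on top of it.
val byGroup = Window.partitionBy($"id", $"grp").orderBy($"date")
val result2 = grouped
  .withColumn("sumValue", sum($"value").over(byGroup))
  .drop("grp")

On the sample data this should yield the same 3.00, 7.00, 5.00, 11.00, 19.00 sequence as the UDAF.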