Spark Scala: create a new column which contains the addition of the previous balance amount for each cid

Initial DF:

cid transAmt transDate
1 10 2-Aug 
1 20 3-Aug
1 30 3-Aug
2 40 2-Aug
2 50 3-Aug
3 60 4-Aug

Output DF:

cid transAmt transDate sumAmt

1 10 2-Aug **10**
1 20 3-Aug **30**
1 30 3-Aug **60**
2 40 2-Aug **40**
2 50 3-Aug **90** 
3 60 4-Aug **60**

I need a new column sumAmt which contains the running total of transAmt for each cid, as shown above.

Use the window sum function to get the cumulative sum:

Example:

df.show()
//+---+------+----------+
//|cid|Amount|transnDate|
//+---+------+----------+
//|  1|    10|     2-Aug|
//|  1|    20|     3-Aug|
//|  2|    40|     2-Aug|
//|  2|    50|     3-Aug|
//|  3|    60|     4-Aug|
//+---+------+----------+
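
For reference, here is a minimal sketch of how this sample df could be built; it is not part of the original answer and assumes a SparkSession named spark is in scope (as in spark-shell):

 import spark.implicits._

 // Hypothetical construction of the sample data shown above
 val df = Seq(
   (1, 10, "2-Aug"),
   (1, 20, "3-Aug"),
   (2, 40, "2-Aug"),
   (2, 50, "3-Aug"),
   (3, 60, "4-Aug")
 ).toDF("cid", "Amount", "transnDate")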

import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions._

// Window per cid, ordered by Amount and then transnDate
// (for this sample data, Amount order matches date order)
val w = Window.partitionBy("cid").orderBy("Amount", "transnDate")

df.withColumn("sumAmt", sum(col("Amount")).over(w)).show()

//+---+------+----------+------+
//|cid|Amount|transnDate|sumAmt|
//+---+------+----------+------+
//|  1|    10|     2-Aug|    10|
//|  1|    20|     3-Aug|    30|
//|  3|    60|     4-Aug|    60|
//|  2|    40|     2-Aug|    40|
//|  2|    50|     3-Aug|    90|
//+---+------+----------+------+

Simply use a window with rowsBetween. This matters for the question's data: cid 1 has two rows on 3-Aug, and with an ORDER BY but no explicit frame Spark defaults to a RANGE frame, which would give both tied rows the same total; a ROWS frame adds them one at a time.

  • Window.unboundedPreceding means the frame has no lower bound
  • Window.currentRow means the frame ends at the current row (obviously)
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import spark.implicits._ // for the $"..." column syntax (assumes a SparkSession named spark)

// Running total per cid, ordered by date; the explicit ROWS frame adds one
// row at a time even when several rows share the same transDate
val cidCategory = Window.partitionBy("cid")
  .orderBy("transDate")
  .rowsBetween(Window.unboundedPreceding, Window.currentRow)

val result = df.withColumn("sumAmt", sum($"transAmt").over(cidCategory))

Output:

//+---+--------+---------+------+
//|cid|transAmt|transDate|sumAmt|
//+---+--------+---------+------+
//|  1|      10|    2-Aug|    10|
//|  1|      20|    3-Aug|    30|
//|  1|      30|    3-Aug|    60|
//|  2|      40|    2-Aug|    40|
//|  2|      50|    3-Aug|    90|
//|  3|      60|    4-Aug|    60|
//+---+--------+---------+------+
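
The same running total can also be written as a Spark SQL window expression; a minimal sketch, assuming the DataFrame is registered under the illustrative view name trans:

df.createOrReplaceTempView("trans")

spark.sql("""
  SELECT cid, transAmt, transDate,
         SUM(transAmt) OVER (PARTITION BY cid
                             ORDER BY transDate
                             ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS sumAmt
  FROM trans
""").show()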