如何根据先前记录的值更新 spark 数据框的列

How to update column of spark dataframe based on the values of previous record

我在 df 中有三列

Col1,col2,col3

X,x1,x2

Z,z1,z2

Y,

X,x3,x4

P,p1,p2

Q,q1,q2

Y

我想做以下事情 当col1=x时,存储col2和col3的值 并在 col1=y 时将这些列值分配给下一行 预期输出

X,x1,x2

Z,z1,z2

Y,x1,x2

X,x3,x4

P,p1,p2

Q,q1,q2

Y,x3,x4

如有任何帮助,我们将不胜感激 注:-spark 1.6

是的,有一个lag函数需要排序

import org.apache.spark.sql.expressions.Window.orderBy
import org.apache.spark.sql.functions.{coalesce, lag}

case class Temp(a: String, b: Option[String], c: Option[String])

val input = ss.createDataFrame(
  Seq(
    Temp("A", Some("a1"), Some("a2")),
    Temp("D", Some("d1"), Some("d2")),
    Temp("B", Some("b1"), Some("b2")),
    Temp("E", None, None),
    Temp("C", None, None)
  ))

+---+----+----+
|  a|   b|   c|
+---+----+----+
|  A|  a1|  a2|
|  D|  d1|  d2|
|  B|  b1|  b2|
|  E|null|null|
|  C|null|null|
+---+----+----+

val order = orderBy($"a")
input
  .withColumn("b", coalesce($"b", lag($"b", 1).over(order)))
  .withColumn("c", coalesce($"c", lag($"c", 1).over(order)))
  .show()

+---+---+---+
|  a|  b|  c|
+---+---+---+
|  A| a1| a2|
|  B| b1| b2|
|  C| b1| b2|
|  D| d1| d2|
|  E| d1| d2|
+---+---+---+

这是一种使用 Window 函数的方法,步骤如下:

  1. 添加 row-identifying 列(如果已有一列则不需要)并将 non-key 列(可能有很多列)合并为一个
  2. 使用条件空值创建 tmp1,使用 last/rowsBetween Window 函数创建 tmp1 back-fill 最后一个 non-null 值
  3. 根据 colstmp2
  4. 有条件地创建 newcols
  5. 使用 foldLeft
  6. newcols 扩展回单独的列

请注意,此解决方案使用 Window 函数而不进行分区,因此可能不适用于大型数据集。

val df = Seq(
  ("X", "x1", "x2"),
  ("Z", "z1", "z2"),
  ("Y", "", ""),
  ("X", "x3", "x4"),
  ("P", "p1", "p2"),
  ("Q", "q1", "q2"),
  ("Y", "", "")
).toDF("col1", "col2", "col3")

import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.Window

val colList = df.columns.filter(_ != "col1")

val df2 = df.select($"col1", monotonically_increasing_id.as("id"),
  struct(colList.map(col): _*).as("cols")
)

val df3 = df2.
  withColumn( "tmp1", when($"col1" === "X", $"cols") ).
  withColumn( "tmp2", last("tmp1", ignoreNulls = true).over(
    Window.orderBy("id").rowsBetween(Window.unboundedPreceding, 0)
  ) )

df3.show
// +----+---+-------+-------+-------+
// |col1| id|   cols|   tmp1|   tmp2|
// +----+---+-------+-------+-------+
// |   X|  0|[x1,x2]|[x1,x2]|[x1,x2]|
// |   Z|  1|[z1,z2]|   null|[x1,x2]|
// |   Y|  2|    [,]|   null|[x1,x2]|
// |   X|  3|[x3,x4]|[x3,x4]|[x3,x4]|
// |   P|  4|[p1,p2]|   null|[x3,x4]|
// |   Q|  5|[q1,q2]|   null|[x3,x4]|
// |   Y|  6|    [,]|   null|[x3,x4]|
// +----+---+-------+-------+-------+

val df4 = df3.withColumn( "newcols",
  when($"col1" === "Y", $"tmp2").otherwise($"cols")
).select($"col1", $"newcols")

df4.show
// +----+-------+
// |col1|newcols|
// +----+-------+
// |   X|[x1,x2]|
// |   Z|[z1,z2]|
// |   Y|[x1,x2]|
// |   X|[x3,x4]|
// |   P|[p1,p2]|
// |   Q|[q1,q2]|
// |   Y|[x3,x4]|
// +----+-------+

val dfResult = colList.foldLeft( df4 )(
  (accDF, c) => accDF.withColumn(c, df4(s"newcols.$c"))
).drop($"newcols")

dfResult.show
// +----+----+----+
// |col1|col2|col3|
// +----+----+----+
// |   X|  x1|  x2|
// |   Z|  z1|  z2|
// |   Y|  x1|  x2|
// |   X|  x3|  x4|
// |   P|  p1|  p2|
// |   Q|  q1|  q2|
// |   Y|  x3|  x4|
// +----+----+----+

[更新]

对于 Spark 1.x,last(colName, ignoreNulls) 在 DataFrame API 中不可用。 work-around 是恢复使用 Spark SQL,它在其 last() 方法中支持 ignore-null:

df2.
  withColumn( "tmp1", when($"col1" === "X", $"cols") ).
  createOrReplaceTempView("df2table")
  // might need to use registerTempTable("df2table") instead

val df3 = spark.sqlContext.sql("""
  select col1, id, cols, tmp1, last(tmp1, true) over (
    order by id rows between unbounded preceding and current row
    ) as tmp2
  from df2table
""")