Split a row into two and dummy some columns

I need to split a row and create a new row by changing the date columns and setting the Amt column of the new row to zero, as in the example below:

Input:  
+---+-----------------------+-----------------------+-----+
|KEY|START_DATE             |END_DATE               |Amt  |
+---+-----------------------+-----------------------+-----+
|0  |2016-12-14T23:59:59.000|2017-10-29T23:59:58.000|100.0|
|0  |2016-12-14T23:59:59.000|2017-10-29T23:59:58.000|200.0|
|0  |2017-10-30T00:00:00.000|2017-11-02T23:59:59.000|67.5 |->split this row: "2017-10-31T23:59:59" falls between START_DATE and END_DATE
|0  |2016-12-14T23:59:59.000|2017-10-29T23:59:58.000|55.3 |
|1  |2016-12-14T23:59:59.000|2017-10-29T23:59:58.000|22.2 |
|1  |2017-10-30T00:00:00.000|2017-11-01T23:59:59.000|11.0 |->split this row: "2017-10-31T23:59:59" falls between START_DATE and END_DATE
|1  |2016-12-14T23:59:59.000|2017-10-29T23:59:58.000|87.33|
+---+-----------------------+-----------------------+-----+

If "2017-10-31T23:59:59" falls between a row's START_DATE and END_DATE, split that row into two: one row with its END_DATE changed and another with its START_DATE changed, and set the new row's Amt to zero, as shown below:

Desired output:

+---+-----------------------+-----------------------+-----+---+
|KEY|START_DATE             |END_DATE               |Amt  |Ind|
+---+-----------------------+-----------------------+-----+---+
|0  |2016-12-14T23:59:59.000|2017-10-29T23:59:58.000|100.0|N  |
|0  |2016-12-14T23:59:59.000|2017-10-29T23:59:58.000|200.0|N  |

|0  |2017-10-30T00:00:00.000|2017-10-30T23:59:59.998|67.5 |N  |->parent row (changed the END_DATE)     
|0  |2017-10-30T23:59:59.999|2017-11-02T23:59:59.000|0.0  |Y  |->split new row (changed the START_DATE and Amt=0.0)

|0  |2016-12-14T23:59:59.000|2017-10-29T23:59:58.000|55.3 |N  |     
|1  |2016-12-14T23:59:59.000|2017-10-29T23:59:58.000|22.2 |N  |

|1  |2017-10-30T00:00:00.000|2017-10-30T23:59:59.998|11.0 |N  |->parent row (changed the END_DATE)    
|1  |2017-10-30T23:59:59.999|2017-11-01T23:59:59.000|0.0  |Y  |->split new row (changed the START_DATE and Amt=0.0)

|1  |2016-12-14T23:59:59.000|2017-10-29T23:59:58.000|87.33|N  |     
+---+-----------------------+-----------------------+-----+---+

I have tried the code below; it lets me duplicate rows, but I cannot modify them on the fly.

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Encoder, Row}
import spark.implicits._

// build a small test DataFrame with the input rows
val df1Columns = Seq("KEY", "START_DATE", "END_DATE", "Amt")

  val df1Schema = new StructType(df1Columns.map(c => StructField(c, StringType, nullable = false)).toArray)
  val input1: Array[String] = Seq("0", "2016-12-14T23:59:59.000", "2017-10-29T23:59:58.000", "100.0").toArray
  val row1: Row = Row.fromSeq(input1)
  val input2: Array[String] = Seq("0", "2016-12-14T23:59:59.000", "2017-10-29T23:59:58.000", "200.0").toArray
  val row2: Row = Row.fromSeq(input2)
  val input3: Array[String] = Seq("0", "2017-10-30T00:00:00.000", "2017-11-01T23:59:59.000", "67.5").toArray
  val row3: Row = Row.fromSeq(input3)
  val input4: Array[String] = Seq("0", "2016-12-14T23:59:59.000", "2017-10-29T23:59:58.000", "55.3").toArray
  val row4: Row = Row.fromSeq(input4)
  val input5: Array[String] = Seq("1", "2016-12-14T23:59:59.000", "2017-10-29T23:59:58.000", "22.2").toArray
  val row5: Row = Row.fromSeq(input5)
  val input6: Array[String] = Seq("1", "2017-10-30T00:00:00.000", "2017-11-01T23:59:59.000", "11.0").toArray
  val row6: Row = Row.fromSeq(input6)
  val input7: Array[String] = Seq("1", "2016-12-14T23:59:59.000", "2017-10-29T23:59:58.000", "87.33").toArray
  val row7: Row = Row.fromSeq(input7)

  val rdd: RDD[Row] = spark.sparkContext.parallelize(Seq(row1, row2, row3, row4, row5, row6, row7))
  val df: DataFrame = spark.createDataFrame(rdd, df1Schema)

  //----------------------------------------------------------------

def encoder(columns: Seq[String]): Encoder[Row] = RowEncoder(StructType(columns.map(StructField(_, StringType, nullable = true))))
val outputColumns = Seq("KEY", "START_DATE", "END_DATE", "Amt","Ind")

  val result = df.groupByKey(r => r.getAs[String]("KEY"))
    .flatMapGroups((_, rowsForAkey) => {
      var result: List[Row] = List()
      for (row <- rowsForAkey) {
        val qrDate = "2017-10-31T23:59:59"
        val currRowStartDate = row.getAs[String]("START_DATE")
        val rowEndDate = row.getAs[String]("END_DATE")
        if (currRowStartDate <= qrDate && qrDate <= rowEndDate) //Quota
        {
          val rLayer = row
          result = result :+ rLayer
        }
        val originalRow = row
        result = result :+ originalRow
      }
      result
      })(encoder(df1Columns)).toDF

  df.show(false)
  result.show(false)

Here is the output of my code:

+---+-----------------------+-----------------------+-----+
|KEY|START_DATE             |END_DATE               |Amt  |
+---+-----------------------+-----------------------+-----+
|0  |2016-12-14T23:59:59.000|2017-10-29T23:59:58.000|100.0|     
|0  |2016-12-14T23:59:59.000|2017-10-29T23:59:58.000|200.0|     
|0  |2017-10-30T00:00:00.000|2017-11-01T23:59:59.000|67.5 |
|0  |2017-10-30T00:00:00.000|2017-11-01T23:59:59.000|67.5 |
|0  |2016-12-14T23:59:59.000|2017-10-29T23:59:58.000|55.3 |     
|1  |2016-12-14T23:59:59.000|2017-10-29T23:59:58.000|22.2 |     
|1  |2017-10-30T00:00:00.000|2017-11-01T23:59:59.000|11.0 |
|1  |2017-10-30T00:00:00.000|2017-11-01T23:59:59.000|11.0 |
|1  |2016-12-14T23:59:59.000|2017-10-29T23:59:58.000|87.33|     
+---+-----------------------+-----------------------+-----+

It looks like you're duplicating the rows rather than changing them.

You could replace the inside of your flatMapGroups function with something like this:

rowsForAkey.flatMap { row =>
  val qrDate = "2017-10-31T23:59:59"
  val currRowStartDate = row.getAs[String]("START_DATE")
  val rowEndDate = row.getAs[String]("END_DATE")
  if (currRowStartDate <= qrDate && qrDate <= rowEndDate) // quota date falls inside this row's range
  {
    val splitDate = endOfDay(currRowStartDate)
    // need to build two rows: the parent keeps its Amt but ends at splitDate,
    // the new row starts at splitDate with Amt set to zero and Ind = "Y"
    val parentRow = Row(row(0), row(1), splitDate, row(3), "N")
    val splitRow = Row(row(0), splitDate, row(2), "0.0", "Y")
    List(parentRow, splitRow)
  }
  else {
    // untouched rows still need the Ind column so every row has five fields
    List(Row(row(0), row(1), row(2), row(3), "N"))
  }
}

Basically, any time you have a for loop building up a list like this in Scala, it's really map or flatMap that you want. Here it's flatMap, because each row gives us one or two elements in the result. I'm assuming you introduce a function endOfDay to produce the right timestamp. Also note that the emitted rows now carry five fields, so the encoder passed to flatMapGroups should be built from outputColumns rather than df1Columns.
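
endOfDay is left abstract above; here is a minimal sketch of one possible implementation, assuming the dates stay as strings in the yyyy-MM-dd'T'HH:mm:ss.SSS form used here (the exact millisecond boundary, .998 vs .999, is up to you):

// hypothetical helper: keep the date part of the timestamp, replace the time with an end-of-day boundary
def endOfDay(ts: String): String =
  ts.split("T")(0) + "T23:59:59.999"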

I know you're probably reading your data in as a DataFrame, but I did want to put forward the idea of using a Dataset[SomeCaseClass]: it's basically a drop-in replacement (you're essentially treating your DataFrame as a Dataset[Row], which is what it is, after all), I think it makes things easier to read, and you get type checking.

Also note that if you import spark.implicits._, you shouldn't need to write the encoder yourself: everything here looks like a string or a double, and encoders for those are already available.
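
As a rough, untested sketch of what that Dataset-based version might look like (the Layer case class and its field names are mine, chosen to mirror the columns above, and it reuses the hypothetical endOfDay helper):

import org.apache.spark.sql.functions.lit
import spark.implicits._

// hypothetical case class mirroring the input columns plus the Ind flag
case class Layer(KEY: String, START_DATE: String, END_DATE: String, Amt: Double, Ind: String)

val qrDate = "2017-10-31T23:59:59"

val ds = df.withColumn("Amt", $"Amt".cast("double"))
  .withColumn("Ind", lit("N"))
  .as[Layer]

val splitDs = ds.groupByKey(_.KEY).flatMapGroups { (_, rows) =>
  rows.flatMap { r =>
    if (r.START_DATE <= qrDate && qrDate <= r.END_DATE) {
      val splitDate = endOfDay(r.START_DATE)
      // parent row keeps its Amt but ends at splitDate; the new row starts there with Amt = 0.0
      List(r.copy(END_DATE = splitDate), r.copy(START_DATE = splitDate, Amt = 0.0, Ind = "Y"))
    } else List(r)
  }
}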

I would suggest you use built-in functions rather than going through this complex RDD route.

I used built-in functions such as lit to populate constants and a udf function to change the time in the date columns.

The idea is to split the dataframe into two and union them at the end (I have commented the code for clarity).

import org.apache.spark.sql.functions._
import spark.implicits._

//udf function to change the time: keep the date part of the second argument and take the time part of the first
def changeTimeInDate = udf((toCopy: String, withCopied: String) => withCopied.split("T")(0) + "T" + toCopy.split("T")(1))

//creating the Ind column populated with N and saving it in a temporary dataframe
val indDF = df.withColumn("Ind", lit("N"))

//keeping only the rows that match the condition mentioned in the question, then changing the Amt, Ind and START_DATE columns
val duplicatedDF = indDF.filter($"START_DATE" <= "2017-10-31T23:59:59" && $"END_DATE" >= "2017-10-31T23:59:59")
  .withColumn("Amt", lit("0.0"))
  .withColumn("Ind", lit("Y"))
  .withColumn("START_DATE", changeTimeInDate($"END_DATE", $"START_DATE"))

//Changing the END_DATE and finally merging both
val result = indDF.withColumn("END_DATE", changeTimeInDate($"START_DATE", $"END_DATE"))
  .union(duplicatedDF)

You should get the desired output:

+---+-----------------------+-----------------------+-----+---+
|KEY|START_DATE             |END_DATE               |Amt  |Ind|
+---+-----------------------+-----------------------+-----+---+
|0  |2016-12-14T23:59:59.000|2017-10-29T23:59:59.000|100.0|N  |
|0  |2016-12-14T23:59:59.000|2017-10-29T23:59:59.000|55.3 |N  |
|0  |2016-12-14T23:59:59.000|2017-10-29T23:59:59.000|200.0|N  |
|0  |2017-10-30T00:00:00.000|2017-11-01T00:00:00.000|67.5 |N  |
|0  |2017-10-30T23:59:59.000|2017-11-01T23:59:59.000|0.0  |Y  |
|1  |2016-12-14T23:59:59.000|2017-10-29T23:59:59.000|22.2 |N  |
|1  |2016-12-14T23:59:59.000|2017-10-29T23:59:59.000|87.33|N  |
|1  |2017-10-30T00:00:00.000|2017-11-01T00:00:00.000|11.0 |N  |
|1  |2017-10-30T23:59:59.000|2017-11-01T23:59:59.000|0.0  |Y  |
+---+-----------------------+-----------------------+-----+---+
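
If you want END_DATE adjusted only on the rows that actually match the split condition (so the other rows keep their original 23:59:58.000 end times, as in the desired output), one possible variation on the last step, using when/otherwise and the same changeTimeInDate udf, could be:

val splitCond = $"START_DATE" <= "2017-10-31T23:59:59" && $"END_DATE" >= "2017-10-31T23:59:59"

//only rows matching the split condition get their END_DATE time adjusted; the rest are left untouched
val result = indDF
  .withColumn("END_DATE", when(splitCond, changeTimeInDate($"START_DATE", $"END_DATE")).otherwise($"END_DATE"))
  .union(duplicatedDF)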