如何解决 Spark/Scala 中的不可变数据框?

How to work around the immutable data frames in Spark/Scala?

我正在尝试将以下 pyspark 代码转换为 scala。如您所知,scala 中的数据帧是 immutable,这限制了我转换以下代码:

pyspark 代码:

 time_frame = ["3m","6m","9m","12m","18m","27m","60m","60m_ab"]
 variable_name = ["var1", "var2", "var3"....., "var30"]
 train_df = sqlContext.sql("select * from someTable")

 for var in variable_name:
     for tf in range(1,len(time_frame)):
         train_df=train_df.withColumn(str(time_frame[tf]+'_'+var), fn.col(str(time_frame[tf]+'_'+var))+fn.col(str(time_frame[tf-1]+'_'+var)))

因此,正如您在上方看到的那样,table 具有不同的列,用于重新创建更多列。然而,Spark/Scala 中数据帧的 immutable 性质是反对的,你能帮我解决一些问题吗?

这是一种方法,首先使用 for-comprehension 生成由列名对组成的元组列表,然后使用 foldLeft 遍历该列表,通过 trainDF 迭代变换 withColumn:

import org.apache.spark.sql.functions._

val timeframes: Seq[String] = ???
val variableNames: Seq[String] = ???

val newCols = for {
  vn <- variableNames
  tf <- 1 until timeframes.size
} yield (timeframes(tf) + "_" + vn, timeframes(tf - 1) + "_" + vn)

val trainDF = spark.sql("""select * from some_table""")

val resultDF = newCols.foldLeft(trainDF)( (accDF, cs) =>
  accDF.withColumn(cs._1, col(cs._1) + col(cs._2))
)

要测试以上代码,只需提供示例输入并创建 table some_table:

val timeframes = Seq("3m", "6m", "9m")
val variableNames = Seq("var1", "var2")

val df = Seq(
  (1, 10, 11, 12, 13, 14, 15),
  (2, 20, 21, 22, 23, 24, 25),
  (3, 30, 31, 32, 33, 34, 35)
).toDF("id", "3m_var1", "6m_var1", "9m_var1", "3m_var2", "6m_var2", "9m_var2")

df.createOrReplaceTempView("some_table")

ResultDF 应如下所示:

resultDF.show
// +---+-------+-------+-------+-------+-------+-------+
// | id|3m_var1|6m_var1|9m_var1|3m_var2|6m_var2|9m_var2|
// +---+-------+-------+-------+-------+-------+-------+
// |  1|     10|     21|     33|     13|     27|     42|
// |  2|     20|     41|     63|     23|     47|     72|
// |  3|     30|     61|     93|     33|     67|    102|
// +---+-------+-------+-------+-------+-------+-------+