How to work around the immutable data frames in Spark/Scala?
I'm trying to convert the following PySpark code to Scala. As you know, data frames in Scala are immutable, which is holding back my attempt to convert this code:

PySpark code:
time_frame = ["3m", "6m", "9m", "12m", "18m", "27m", "60m", "60m_ab"]
variable_name = ["var1", "var2", "var3", ..., "var30"]
train_df = sqlContext.sql("select * from someTable")
for var in variable_name:
    for tf in range(1, len(time_frame)):
        train_df = train_df.withColumn(str(time_frame[tf] + '_' + var), fn.col(str(time_frame[tf] + '_' + var)) + fn.col(str(time_frame[tf - 1] + '_' + var)))
So, as you can see above, the table has various columns that are used to recreate more columns. However, the immutable nature of data frames in Spark/Scala stands in the way. Can you help me work around this?
Here is one approach: first use a for-comprehension to generate a list of tuples of column-name pairs, then iterate over that list with foldLeft, transforming trainDF step by step via withColumn:
import org.apache.spark.sql.functions._

val timeframes: Seq[String] = ???
val variableNames: Seq[String] = ???

val newCols = for {
  vn <- variableNames
  tf <- 1 until timeframes.size
} yield (timeframes(tf) + "_" + vn, timeframes(tf - 1) + "_" + vn)

val trainDF = spark.sql("""select * from some_table""")

val resultDF = newCols.foldLeft(trainDF)( (accDF, cs) =>
  accDF.withColumn(cs._1, col(cs._1) + col(cs._2))
)
To test the above code, simply provide some sample input and create the table some_table:
val timeframes = Seq("3m", "6m", "9m")
val variableNames = Seq("var1", "var2")

val df = Seq(
  (1, 10, 11, 12, 13, 14, 15),
  (2, 20, 21, 22, 23, 24, 25),
  (3, 30, 31, 32, 33, 34, 35)
).toDF("id", "3m_var1", "6m_var1", "9m_var1", "3m_var2", "6m_var2", "9m_var2")

df.createOrReplaceTempView("some_table")
resultDF should look like this:
resultDF.show
// +---+-------+-------+-------+-------+-------+-------+
// | id|3m_var1|6m_var1|9m_var1|3m_var2|6m_var2|9m_var2|
// +---+-------+-------+-------+-------+-------+-------+
// | 1| 10| 21| 33| 13| 27| 42|
// | 2| 20| 41| 63| 23| 47| 72|
// | 3| 30| 61| 93| 33| 67| 102|
// +---+-------+-------+-------+-------+-------+-------+
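To double-check the arithmetic, here is a plain-Python sketch (no Spark needed) of what the fold computes: each timeframe column becomes the running sum of itself and all earlier timeframes for that variable. The dict-of-lists "frame" and the `with_column` helper are illustrative stand-ins, not Spark API.

```python
from functools import reduce

time_frames = ["3m", "6m", "9m"]
variable_names = ["var1", "var2"]

# Column-name pairs (current, previous), same as the Scala for-comprehension.
new_cols = [
    (f"{time_frames[tf]}_{vn}", f"{time_frames[tf - 1]}_{vn}")
    for vn in variable_names
    for tf in range(1, len(time_frames))
]

# An in-memory "frame": column name -> list of values (the sample data above).
frame = {
    "id":      [1, 2, 3],
    "3m_var1": [10, 20, 30], "6m_var1": [11, 21, 31], "9m_var1": [12, 22, 32],
    "3m_var2": [13, 23, 33], "6m_var2": [14, 24, 34], "9m_var2": [15, 25, 35],
}

def with_column(df, pair):
    """One fold step: replace column `cur` with cur + prev, element-wise."""
    cur, prev = pair
    return {**df, cur: [a + b for a, b in zip(df[cur], df[prev])]}

result = reduce(with_column, new_cols, frame)
print(result["9m_var1"])  # prints [33, 63, 93], matching the table above
```

Because each pair's left column was updated by the previous step before it is read as the right column of the next pair, the sums accumulate left to right, which is exactly why the order of `new_cols` matters in the foldLeft.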