Spark dataframes: Extract a column based on the value of another column

I have a dataframe containing transactions with a joined price list:

+----------+----------+------+-------+-------+
|   paid   | currency | EUR  |  USD  |  GBP  |
+----------+----------+------+-------+-------+
|   49.5   |   EUR    | 99   |  79   |  69   |
+----------+----------+------+-------+-------+

The customer has paid 49.5 EUR, as indicated by the "currency" column. I now want to compare the price paid with the price from the price list.

So I need to access the correct column based on the value of "currency", something like this:

df.withColumn("saved", df.col(df.col($"currency")) - df.col("paid"))

which I hoped would resolve to

df.withColumn("saved", df.col("EUR") - df.col("paid"))

This fails, however: df.col expects a column name as a String, not another Column. I have tried everything I could think of, including UDFs, getting nowhere.

I suppose there is some elegant solution for this? Can anybody help out here?

I can't think of a way to do this with a DataFrame, and I doubt there is an easy one, but if you turn that table into an RDD:

// Off the top of my head, so warn if wrong.
// A match .. case would be more elegant; see the sketch below.
// Listed price for the row's currency minus the amount paid.
def d(l: (Double, String, Int, Int, Int)): Double = {
  if (l._2 == "EUR")
    l._3 - l._1
  else if (l._2 == "USD")
    l._4 - l._1
  else
    l._5 - l._1
}

// Convert each Row to a typed tuple, then pair it with its difference.
val rdd = df.rdd.map(r =>
  (r.getDouble(0), r.getString(1), r.getInt(2), r.getInt(3), r.getInt(4)))
val diff = rdd.map(t => (t, d(t)))

It may still throw type errors depending on your actual schema; hopefully you can sort those out.
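
The match .. case variant alluded to in the comment could look roughly like this (a sketch, assuming the same tuple layout as above):

def d2(l: (Double, String, Int, Int, Int)): Double = l match {
  case (paid, "EUR", eur, _, _) => eur - paid
  case (paid, "USD", _, usd, _) => usd - paid
  case (paid, _, _, _, gbp)     => gbp - paid
}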

Assuming the column names match the values in the currency column:

import org.apache.spark.sql.functions.{lit, col, coalesce, when}
import org.apache.spark.sql.Column
import sqlContext.implicits._ // for $"..." and toDF (automatic in spark-shell)

// Dummy data
val df = sc.parallelize(Seq(
  (49.5, "EUR", 99, 79, 69), (100.0, "GBP", 80, 120, 50)
)).toDF("paid", "currency", "EUR", "USD", "GBP")

// A list of available currencies 
val currencies: List[String] = List("EUR", "USD", "GBP")

// Select listed value
val listedPrice: Column = coalesce(
  currencies.map(c => when($"currency" === c, col(c)).otherwise(lit(null))): _*)

df.select($"*", (listedPrice - $"paid").alias("difference")).show

// +-----+--------+---+---+---+----------+
// | paid|currency|EUR|USD|GBP|difference|
// +-----+--------+---+---+---+----------+
// | 49.5|     EUR| 99| 79| 69|      49.5|
// |100.0|     GBP| 80|120| 50|     -50.0|
// +-----+--------+---+---+---+----------+
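
As a side note, the otherwise(lit(null)) part is redundant: when without an otherwise already evaluates to null for rows that don't match, so the same column can be built a little more tersely:

val listedPriceShort: Column = coalesce(
  currencies.map(c => when($"currency" === c, col(c))): _*)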

The SQL equivalent of the listedPrice expression looks like this:

COALESCE(
  CASE WHEN (currency = 'EUR') THEN EUR ELSE null END,
  CASE WHEN (currency = 'USD') THEN USD ELSE null END,
  CASE WHEN (currency = 'GBP') THEN GBP ELSE null END
)
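
You don't have to take the translation on faith: printing the Column shows the expression it carries (a sketch; the exact rendering differs between Spark versions):

println(listedPrice)
// coalesce(CASE WHEN (currency = EUR) THEN EUR ELSE null END, ...)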

An alternative using foldLeft:

import org.apache.spark.sql.functions.when

val listedPriceViaFold = currencies.foldLeft(
  lit(null))((acc, c) => when($"currency" === c, col(c)).otherwise(acc))

df.select($"*", (listedPriceViaFold - $"paid").alias("difference")).show

// +-----+--------+---+---+---+----------+
// | paid|currency|EUR|USD|GBP|difference|
// +-----+--------+---+---+---+----------+
// | 49.5|     EUR| 99| 79| 69|      49.5|
// |100.0|     GBP| 80|120| 50|     -50.0|
// +-----+--------+---+---+---+----------+

where listedPriceViaFold translates to the following SQL:

CASE
  WHEN (currency = 'GBP') THEN GBP
  ELSE CASE
    WHEN (currency = 'USD') THEN USD
    ELSE CASE
      WHEN (currency = 'EUR') THEN EUR
      ELSE null
    END
  END
END

Unfortunately, I am not aware of any built-in function that could directly express SQL like this:

CASE currency
    WHEN 'EUR' THEN EUR
    WHEN 'USD' THEN USD
    WHEN 'GBP' THEN GBP
    ELSE null
END

but you can use this construct in raw SQL.
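
A minimal sketch of the raw-SQL route, assuming Spark 1.x (in Spark 2+ use createOrReplaceTempView and spark.sql instead):

df.registerTempTable("df")

sqlContext.sql("""
  SELECT *,
         CASE currency
             WHEN 'EUR' THEN EUR
             WHEN 'USD' THEN USD
             WHEN 'GBP' THEN GBP
             ELSE null
         END - paid AS difference
  FROM df""").show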

If my assumption is incorrect, you can simply add a mapping between the column names and the values in the currency column.
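
For instance, if the data used hypothetical lowercase codes that don't match the column names, a small map would do (the values below are made up for illustration):

// Assumed mapping from currency values to price-list column names.
val currencyToColumn: Map[String, String] = Map(
  "eur" -> "EUR", "usd" -> "USD", "gbp" -> "GBP")

val listedPriceMapped: Column = coalesce(
  currencyToColumn.toSeq.map {
    case (value, colName) => when($"currency" === value, col(colName))
  }: _*)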

Edit:

Another option, which could be efficient if the source supports predicate pushdown and effective column pruning, is to subset by currency and union:

currencies.map(
  // For each currency keep the matching rows and add the difference
  c => df.where($"currency" === c).withColumn("difference", col(c) - $"paid")
).reduce((df1, df2) => df1.unionAll(df2)) // unionAll is union in Spark 2+

which is equivalent to SQL like this:

SELECT *,  EUR - paid AS difference FROM df WHERE currency = 'EUR'
UNION ALL
SELECT *,  USD - paid AS difference FROM df WHERE currency = 'USD'
UNION ALL
SELECT *,  GBP - paid AS difference FROM df WHERE currency = 'GBP'
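
On the dummy data above, appending .show to the reduced DataFrame should print the same two rows as before (row order follows the order of the currencies list):

// +-----+--------+---+---+---+----------+
// | paid|currency|EUR|USD|GBP|difference|
// +-----+--------+---+---+---+----------+
// | 49.5|     EUR| 99| 79| 69|      49.5|
// |100.0|     GBP| 80|120| 50|     -50.0|
// +-----+--------+---+---+---+----------+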