How to iterate through an array in a dataframe column, and do a lookup in another dataframe - Scala, Spark

I have a dataframe (df1) that looks like this:

+-------------------------+
|foods                    |
+-------------------------+
|[Apple, Apple, Banana]   |
|[Apple, Carrot, Broccoli]|
|[Spinach]                |
+-------------------------+

I want to look up each food in another dataframe (df2), which looks like this:

+------------------+
|food    |category |
+------------------+
|Apple   |Fruit    |
|Carrot  |Vegetable|
|Broccoli|Vegetable|
|Banana  |Fruit    |
|Spinach |Vegetable|
+------------------+

The resulting dataframe should look like this:

+-------------------------+-----------------------------+---------+
|foods                    |categories                   |has fruit|
+-------------------------+-----------------------------+---------+
|[Apple, Apple, Banana]   |[Fruit, Fruit, Fruit]        |true     |
|[Apple, Carrot, Broccoli]|[Fruit, Vegetable, Vegetable]|true     |
|[Spinach]                |[Vegetable]                  |false    |
+-------------------------+-----------------------------+---------+

How can I do this in Spark/Scala? I'm new to Scala, so an explanation of the code would also help. Thanks!


Here is the code I'm currently using, but I get an org.apache.spark.SparkException: Task not serializable error, Caused by: java.io.NotSerializableException: org.apache.spark.sql.Column

var schema = df1.schema("foods").dataType

def func = udf((x: Seq[String]) => {
    x.map(x => df2.filter(col("food") === x).select(col("category")).head().getString(0))
}, schema)

df1.withColumn("categories", func($"foods")).show()

I would appreciate some help. The code doesn't need to be clean. Thanks.


I tried turning df2 into a Map and changed the code slightly:

var mappy = df2.map{ r => (r.getString(0), r.getString(1))}.collect.toMap

var schema = df1.schema("foods").dataType

def func = udf((x: Seq[String]) => {
     x.map(x => mappy.getOrElse(x, ""))
}, schema)

df1.withColumn("categories", func($"foods")).show()

However, now I get this error:

java.lang.StackOverflowError
  at java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1189)
  at java.io.ObjectOutputStream.defaultWriteFields(ObjectOutputStream.java:1548)
  at java.io.ObjectOutputStream.writeSerialData(ObjectOutputStream.java:1509)
  at java.io.ObjectOutputStream.writeOrdinaryObject(ObjectOutputStream.java:1432)
  at java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1178)
  at java.io.ObjectOutputStream.defaultWriteFields(ObjectOutputStream.java:1548)
  at java.io.ObjectOutputStream.writeSerialData(ObjectOutputStream.java:1509)
  at java.io.ObjectOutputStream.writeOrdinaryObject(ObjectOutputStream.java:1432)
  at java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1178)
  at java.io.ObjectOutputStream.defaultWriteFields(ObjectOutputStream.java:1548)
  at java.io.ObjectOutputStream.writeSerialData(ObjectOutputStream.java:1509)
  at java.io.ObjectOutputStream.writeOrdinaryObject(ObjectOutputStream.java:1432)

(repeats)

Sorry for the messy code. This is for analysis, not production. Thanks again!

I prepared your input DataFrames like this:

// The import below is needed to use the toDF() method to create DataFrames on the fly.
  import spark.implicits._

  val df1 = List(
    (List("Apple", "Apple", "Banana")),
    (List("Apple", "Carrot", "Broccoli")),
    (List("Spinach"))
  ).toDF("foods")

  val df2 = List(
    ("Apple", "Fruit"),
    ("Carrot", "Vegetable"),
    ("Broccoli", "Vegetable"),
    ("Banana", "Fruit"),
    ("Spinach", "Vegetable")
  ).toDF("food", "category")

My simple solution, using DataFrames with groupBy and aggregate functions, produces the desired output (shown below):

// imports 
import org.apache.spark.sql.functions._

// code
  // explode(): creates a new row for every element in the array
  val df1_altered = df1
    .withColumn("each_food_from_list", explode(col("foods")))

  // df1_altered.show()
  val df2_altered = df2
    .withColumn(
      "has_fruit",
      when(col("category").equalTo(lit("Fruit")), true).otherwise(false)
    )
  // df2_altered.show()

  df1_altered
    .join(df2_altered, df1_altered("each_food_from_list") === df2_altered("food"), "inner")
    // groupBy(): acts as the opposite of explode(); it groups multiple rows back
    // into one based on a column, and requires at least one aggregate function
    .groupBy(col("foods"))
    .agg(
      collect_list(col("category")) as "categories",
      max(col("has_fruit")) as "has fruit"
    )
    .show(false)
// +-------------------------+-----------------------------+---------+
// |foods                    |categories                   |has fruit|
// +-------------------------+-----------------------------+---------+
// |[Apple, Apple, Banana]   |[Fruit, Fruit, Fruit]        |true     |
// |[Apple, Carrot, Broccoli]|[Fruit, Vegetable, Vegetable]|true     |
// |[Spinach]                |[Vegetable]                  |false    |
// +-------------------------+-----------------------------+---------+
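
To see what the three steps do without Spark in the picture, here is a minimal sketch of the same explode / join / groupBy pipeline on plain Scala collections. It is only an analogy, not how Spark executes the query. The names ExplodeJoinSketch, exploded, joined and regrouped are made up for illustration; the data is the sample from the question.

```scala
object ExplodeJoinSketch {
  // Sample data from the question, as plain Scala collections.
  val foodsRows: List[List[String]] = List(
    List("Apple", "Apple", "Banana"),
    List("Apple", "Carrot", "Broccoli"),
    List("Spinach")
  )

  val categoryOf: Map[String, String] = Map(
    "Apple"    -> "Fruit",
    "Carrot"   -> "Vegetable",
    "Broccoli" -> "Vegetable",
    "Banana"   -> "Fruit",
    "Spinach"  -> "Vegetable"
  )

  // "explode": one (original row, single food) pair per array element.
  val exploded: List[(List[String], String)] =
    foodsRows.flatMap(row => row.map(food => (row, food)))

  // "inner join": keep only pairs whose food has a category, and attach it.
  val joined: List[(List[String], String)] =
    exploded.flatMap { case (row, food) =>
      categoryOf.get(food).map(category => (row, category))
    }

  // "groupBy + agg": collect the categories per original row and check for fruit.
  val regrouped: Map[List[String], (List[String], Boolean)] =
    joined.groupBy(_._1).map { case (row, pairs) =>
      val categories = pairs.map(_._2)
      (row, (categories, categories.contains("Fruit")))
    }

  def main(args: Array[String]): Unit =
    foodsRows.foreach { row =>
      val (categories, hasFruit) = regrouped(row)
      println(s"$row -> $categories, has fruit = $hasFruit")
    }
}
```

Because the grouping key here is the array itself, two identical rows in the input would collapse into a single group. The Spark version that groups by col("foods") alone behaves the same way.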

Edit: To handle duplicate rows in the first DataFrame (df1), you can add a generated ID column and then group by both columns.

The three changed lines of code are marked below as change1, change2 and change3:

// change1
val df1_with_id = df1.withColumn("id", monotonically_increasing_id())
// change2
  val df1_altered = df1_with_id
    .withColumn("each_food_from_list", explode(col("foods")))

  // df1_altered.show()
  val df2_altered = df2
    .withColumn(
      "has_fruit",
      when(col("category").equalTo(lit("Fruit")), true).otherwise(false)
    )
  // df2_altered.show()

  df1_altered
    .join(df2_altered, df1_altered("each_food_from_list") === df2_altered("food"), "inner")
// change3
    .groupBy(col("id"),col("foods"))
    .agg(
      collect_list(col("category")) as "categories",
      max(col("has_fruit")) as "has fruit"
    )
    .show(false)


//+---+-------------------------+-----------------------------+---------+
//|id |foods                    |categories                   |has fruit|
//+---+-------------------------+-----------------------------+---------+
//|0  |[Apple, Apple, Banana]   |[Fruit, Fruit, Fruit]        |true     |
//|1  |[Apple, Apple, Banana]   |[Fruit, Fruit, Fruit]        |true     |
//|3  |[Spinach]                |[Vegetable]                  |false    |
//|2  |[Apple, Carrot, Broccoli]|[Fruit, Vegetable, Vegetable]|true     |
//+---+-------------------------+-----------------------------+---------+
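
If df2 is small enough to collect to the driver, a join-free alternative is to turn it into a Map, broadcast it, and look each food up inside a UDF, which is close to what the question's second attempt was aiming for. This keeps the categories in the same order as the foods and needs no ID column for duplicate rows. The sketch below is not part of the answer above and rests on assumptions: the object name BroadcastLookupSketch, the helper withCategories, the local[*] session, and the empty-string default for unknown foods (borrowed from the question's getOrElse(x, "")) are all illustrative choices. It also assumes the foods column contains no null arrays.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.{array_contains, col, udf}

object BroadcastLookupSketch {

  // Hypothetical helper: adds "categories" and "has fruit" columns to df1
  // by looking every food up in a broadcast copy of df2.
  def withCategories(spark: SparkSession, df1: DataFrame, df2: DataFrame): DataFrame = {
    import spark.implicits._

    // Collect the (assumed small) lookup table to the driver as an immutable Map.
    val categoryOf: Map[String, String] = df2
      .select("food", "category")
      .as[(String, String)]
      .collect()
      .toMap

    // Broadcast the Map so executors read a shared copy of it.
    val categoryOfBc = spark.sparkContext.broadcast(categoryOf)

    // The UDF only touches the broadcast value, never a DataFrame or a Column.
    val toCategories = udf { (foods: Seq[String]) =>
      foods.map(food => categoryOfBc.value.getOrElse(food, ""))
    }

    df1
      .withColumn("categories", toCategories(col("foods")))
      .withColumn("has fruit", array_contains(col("categories"), "Fruit"))
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("broadcast-lookup-sketch")
      .getOrCreate()
    import spark.implicits._

    val df1 = List(
      List("Apple", "Apple", "Banana"),
      List("Apple", "Carrot", "Broccoli"),
      List("Spinach")
    ).toDF("foods")

    val df2 = List(
      ("Apple", "Fruit"),
      ("Carrot", "Vegetable"),
      ("Broccoli", "Vegetable"),
      ("Banana", "Fruit"),
      ("Spinach", "Vegetable")
    ).toDF("food", "category")

    withCategories(spark, df1, df2).show(false)
    spark.stop()
  }
}
```

The trade-off is that the whole lookup table has to fit in memory on the driver and on every executor. For a large df2, the explode and join approach above scales better.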