如何像一个热编码器一样使用 spark scala 将单个多分类列拆分为二进制?

How split single multiple categorical column into binary like one hot encoder use spark scala?

我的数据是这样的:

+---+---------+
| id|cate_list|
+---+---------+
|  0|  a,b,c,d|
|  1|    b,c,d|
|  2|      a,b|
|  3|        a|
|  4|a,b,c,d,e|
|  5|        e|
+---+---------+

我想要的是这样的:

-------------------------
| id|cate_list|a|b|c|d|e|
-------------------------
|  0|  a,b,c,d|1|1|1|1|0|
|  1|    b,c,d|0|1|1|1|0|
|  2|      a,b|1|1|0|0|0|
|  3|        a|1|0|0|0|0|
|  4|a,b,c,d,e|1|1|1|1|1|
|  5|        e|0|0|0|0|1|
-------------------------

我使用了 spark ML OneHotEncoder 并尝试了很多方法,最后我得到了这个:

+---+---------+-------------+-------------+
| id|cate_list|categoryIndex|  categoryVec|
+---+---------+-------------+-------------+
|  0|        a|          0.0|(4,[0],[1.0])|
|  0|        b|          1.0|(4,[1],[1.0])|
|  0|        c|          2.0|(4,[2],[1.0])|
|  0|        d|          3.0|(4,[3],[1.0])|
|  1|        b|          1.0|(4,[1],[1.0])|
|  1|        c|          2.0|(4,[2],[1.0])|
|  1|        d|          3.0|(4,[3],[1.0])|
|  2|        a|          0.0|(4,[0],[1.0])|
|  2|        b|          1.0|(4,[1],[1.0])|
|  3|        a|          0.0|(4,[0],[1.0])|
|  4|        a|          0.0|(4,[0],[1.0])|
|  4|        b|          1.0|(4,[1],[1.0])|
|  4|        c|          2.0|(4,[2],[1.0])|
|  4|        d|          3.0|(4,[3],[1.0])|
|  4|        e|          4.0|    (4,[],[])|
|  5|        e|          4.0|    (4,[],[])|
+---+---------+-------------+-------------+

这不是我需要的。当我使用 python 时,它真的很简单,几乎两行代码就可以解决这个问题。 Scala太难了

我的代码:

val df_split = df.withColumn("cate_list", explode(split($"cate_list", ",")))

val indexer = new StringIndexer()
  .setInputCol("cate_list")
  .setOutputCol("categoryIndex")
  .fit(df_split)
val indexed = indexer.transform(df_split)

val encoder = new OneHotEncoder()
  .setInputCol("categoryIndex")
  .setOutputCol("categoryVec")
val encoded = encoder.transform(indexed)

针对问题的初始数据的一种天真直接的方法。

我们应该有一个 udf 来计算目标单元格值,期望 cate_list 值和目标列名称:

val cateListContains = udf((cateList: String, item: String) => if (cateList.contains(item)) 1 else 0)

我们要提取一系列列名:

val targetColumns = Seq("a", "b", "c", "d", "e")

让我们 foldLeft 来源 DataFrame:

val resultDf = targetColumns.foldLeft(dfSrc) {
  case (df, item) => 
    df.withColumn(item, cateListContains($"cate_list", lit(item)))
}

它正好产生:

+---+---------+---+---+---+---+---+
|id |cate_list|a  |b  |c  |d  |e  |
+---+---------+---+---+---+---+---+
|0  |a,b,c,d  |1  |1  |1  |1  |0  |
|1  |b,c,d    |0  |1  |1  |1  |0  |
|2  |a,b      |1  |1  |0  |0  |0  |
|3  |a        |1  |0  |0  |0  |0  |
|4  |a,b,c,d,e|1  |1  |1  |1  |1  |
|5  |e        |0  |0  |0  |0  |1  |
+---+---------+---+---+---+---+---+

您可以使用 array_contains,其中 returns 一个布尔值,然后将其转换为 int

import org.apache.spark.sql.functions.array_contains

val aa = sc.parallelize(Array((0, "a,b,c,d"), (1, "b,c,d"), (2, "a, b"), (3, "a"), (4, "a,b,c,d,e"), (5, "e")))
var df = aa.toDF("id", "cate_list")   // create your data
val categories = Seq("a", "b", "c", "d", "e")
categories.foreach {col => 
  df = df.withColumn(col, array_contains(split($"cate_list", ","), col).cast("int"))
}
df.show()

结果:

+---+---------+---+---+---+---+---+
| id|cate_list|  a|  b|  c|  d|  e|
+---+---------+---+---+---+---+---+
|  0|  a,b,c,d|  1|  1|  1|  1|  0|
|  1|    b,c,d|  0|  1|  1|  1|  0|
|  2|     a, b|  1|  0|  0|  0|  0|
|  3|        a|  1|  0|  0|  0|  0|
|  4|a,b,c,d,e|  1|  1|  1|  1|  1|
|  5|        e|  0|  0|  0|  0|  1|
+---+---------+---+---+---+---+---+