Spark - How to use QuantileDiscretizer with RandomForestClassifier

Is it possible to use QuantileDiscretizer, keeping NaN values, together with a RandomForestClassifier?

I keep getting an error like this:

18/03/23 17:38:15 ERROR Executor: Exception in task 3.0 in stage 133.0 (TID 381)
java.lang.IllegalArgumentException: DecisionTree given invalid data: Feature 1 is categorical with values in {0,...,1}, but a data point gives it value 2.0.
  Bad data point: (1.0,[1.0,2.0])

Example

The idea here is to create a numeric column and discretize it using quantiles, keeping the invalid numbers (NaN) in a special bucket.

import org.apache.spark.ml.feature.{StringIndexer, VectorAssembler,
  QuantileDiscretizer}
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.RandomForestClassifier

val tseq = Seq((0, "a", 1.0), (1, "b", 0.0), (2, "c", 2.0),
               (3, "a", 1.0), (4, "a", 3.0), (5, "c", Double.NaN))
// SparkInit.ss holds the SparkSession
val tdf = SparkInit.ss.createDataFrame(tseq).toDF("id", "category", "class")
val indexer = new StringIndexer()
  .setInputCol("category")
  .setOutputCol("categoryIndex")
val discr = new QuantileDiscretizer()
  .setInputCol("class")
  .setOutputCol("quant")
  .setNumBuckets(2)
  .setHandleInvalid("keep")
val assembler = new VectorAssembler()
  .setInputCols(Array("categoryIndex", "quant"))
  .setOutputCol("features")
val rf = new RandomForestClassifier()
  .setLabelCol("categoryIndex")
  .setFeaturesCol("features")
  .setNumTrees(3)
new Pipeline()
  .setStages(Array(indexer, discr, assembler, rf))
  .fit(tdf)
  .transform(tdf)
  .show()

Without trying to fit the random forest, I get a DataFrame like this:

+---+--------+-----+-------------+-----+---------+
| id|category|class|categoryIndex|quant| features|
+---+--------+-----+-------------+-----+---------+
|  0|       a|  1.0|          0.0|  1.0|[0.0,1.0]|
|  1|       b|  0.0|          2.0|  0.0|[2.0,0.0]|
|  2|       c|  2.0|          1.0|  1.0|[1.0,1.0]|
|  3|       a|  1.0|          0.0|  1.0|[0.0,1.0]|
|  4|       a|  3.0|          0.0|  1.0|[0.0,1.0]|
|  5|       c|  NaN|          1.0|  2.0|[1.0,2.0]|
+---+--------+-----+-------------+-----+---------+
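
That output can be reproduced by running the same pipeline without the final classifier stage (a minimal sketch reusing the stages defined above):

// Same stages as above, minus the RandomForestClassifier, just to
// inspect the intermediate columns.
new Pipeline()
  .setStages(Array(indexer, discr, assembler))
  .fit(tdf)
  .transform(tdf)
  .show()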

If I try to fit the model, I get the error:

18/03/23 17:54:12 WARN DecisionTreeMetadata: DecisionTree reducing maxBins from 32 to 6 (= number of training instances)
18/03/23 17:54:12 WARN BlockManager: Putting block rdd_490_3 failed due to an exception
18/03/23 17:54:12 WARN BlockManager: Block rdd_490_3 could not be removed as it was not found on disk or in memory
18/03/23 17:54:12 ERROR Executor: Exception in task 3.0 in stage 143.0 (TID 414)
java.lang.IllegalArgumentException: DecisionTree given invalid data: Feature 1 is categorical with values in {0,...,1}, but a data point gives it value 2.0.
  Bad data point: (1.0,[1.0,2.0])
    at org.apache.spark.ml.tree.impl.TreePoint$.findBin(TreePoint.scala:124)
    at org.apache.spark.ml.tree.impl.TreePoint$.org$apache$spark$ml$tree$impl$TreePoint$$labeledPointToTreePoint(TreePoint.scala:93)
    at org.apache.spark.ml.tree.impl.TreePoint$$anonfun$convertToTreeRDD.apply(TreePoint.scala:73)
    at org.apache.spark.ml.tree.impl.TreePoint$$anonfun$convertToTreeRDD.apply(TreePoint.scala:72)
    at scala.collection.Iterator$$anon.next(Iterator.scala:410)
    at scala.collection.Iterator$$anon.next(Iterator.scala:410)
    at org.apache.spark.storage.memory.MemoryStore.putIteratorAsValues(MemoryStore.scala:216)

Does QuantileDiscretizer insert some kind of metadata about the special extra bucket? Strangely enough, I was previously able to build a model with columns holding the same values, just without forcing any discretization.

Update

Yes, the column does have metadata attached to it, and it looks like this:

org.apache.spark.sql.types.Metadata = {"ml_attr":
   {"ord":true,
    "vals":["-Infinity, 5.0","5.0, 10.0","10.0, Infinity"],
    "type":"nominal"}
}

The question now might be: how do I correctly set up the metadata to deal with values like Double.NaN?
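
One possible direction (an untested sketch; NominalAttribute lives in org.apache.spark.ml.attribute, and df is a hypothetical name for the DataFrame after the discretizer stage) would be to attach nominal metadata that declares the extra "keep" bucket as a legal category, before the column goes through the VectorAssembler:

import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.sql.functions.col

// Untested sketch: declare three nominal values for "quant" -- the two
// quantile buckets plus the special "keep" bucket for NaN -- so that
// index 2.0 becomes a legal category for the tree.
val quantMeta = NominalAttribute.defaultAttr
  .withName("quant")
  .withNumValues(3)
  .toMetadata()
val patched = df.withColumn("quant", col("quant").as("quant", quantMeta))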

The workaround I actually used was simply to remove the associated metadata from the discretized column, letting the decision tree implementation decide what to do with the data. I think the column then effectively becomes a numeric one (e.g. [0, 1, 2, 2, 1]); however, if too many distinct values appear, the column may end up being discretized again (look for the maxBins parameter, tweaked in the sketch below).
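
For reference, maxBins can be set on the classifier itself; a hypothetical replacement for the rf definition above:

// maxBins bounds how many bins a continuous feature is split into (and,
// as the warning above shows, it is also capped by the number of
// training instances).
val rf = new RandomForestClassifier()
  .setLabelCol("categoryIndex")
  .setFeaturesCol("features")
  .setNumTrees(3)
  .setMaxBins(64)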

In my case, the simplest way to remove the metadata was to fill the DataFrame after applying the QuantileDiscretizer:

// Nothing is actually filled in my case, since there were no missing
// values before this operation.
df.na.fill(Double.NaN, Array("quant"))

I'm fairly sure you could also remove the metadata manually by accessing the column object directly.
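
A sketch of that idea (untested; it relies on the aliasing trick shown in the update below, with df as in the fill snippet above):

import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.Metadata

// Re-alias "quant" with empty metadata, dropping the nominal attribute
// that the QuantileDiscretizer attached, so the tree treats the column
// as a plain numeric one.
val cleaned = df.withColumn("quant", col("quant").as("quant", Metadata.empty))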

Update

We can change the metadata of a column by creating an alias (reference):

val metadata: Metadata = ...
df.select($"colA".as("colB", metadata))

This answer describes a way to get a column's metadata by fetching the corresponding StructField from the DataFrame's schema.
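
For completeness, a minimal sketch of that lookup (the column name "quant" is taken from the example above):

import org.apache.spark.sql.types.Metadata

// Fetch the StructField for "quant" from the schema and read its metadata.
val quantMeta: Metadata = df.schema("quant").metadata
println(quantMeta.json)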