使用 VectorAssembler 处理动态列
Dealing with dynamic columns with VectorAssembler
使用 sparks vector assembler 需要预先定义要组装的列。
但是,如果在前面的步骤将修改数据框的列的管道中使用矢量汇编程序,我如何指定列而不手动对所有值进行硬编码?
因为 df.columns
将 不 包含正确的值,当调用矢量汇编程序的构造函数时,目前我没有看到另一种处理方法或拆分管道 - 这也很糟糕,因为 CrossValidator 将不再正常工作。
val vectorAssembler = new VectorAssembler()
.setInputCols(df.columns
.filter(!_.contains("target"))
.filter(!_.contains("idNumber")))
.setOutputCol("features")
编辑
的初始 df
---+------+---+-
|foo| id|baz|
+---+------+---+
| 0| 1 | A|
| 1|2 | A|
| 0| 3 | null|
| 1| 4 | C|
+---+------+---+
将进行如下改造。您可以看到 nan 值将被估算为具有最常见的原始列和派生的一些特征,例如如此处所述 isA
如果 baz 为 A,则为 1,否则为 0,如果最初为 null N
+---+------+---+-------+
|foo|id |baz| isA |
+---+------+---+-------+
| 0| 1 | A| 1 |
| 1|2 | A|1 |
| 0| 3 | A| n |
| 1| 4 | C| 0 |
+---+------+---+-------+
稍后在管道中,使用 stringIndexer 使数据适合 ML/vectorAssembler。
isA
在原始 df 中不存在,但在 "only" 输出列中不存在此帧中除 foo 和 id 列之外的所有列都应由矢量汇编器转换。
我希望现在更清楚了。
如果我没看错你的问题,答案会非常简单明了,你只需要使用之前转换器的.getOutputCol
。
示例(来自官方文档):
// Prepare training documents from a list of (id, text, label) tuples.
val training = spark.createDataFrame(Seq(
(0L, "a b c d e spark", 1.0),
(1L, "b d", 0.0),
(2L, "spark f g h", 1.0),
(3L, "hadoop mapreduce", 0.0)
)).toDF("id", "text", "label")
// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
val tokenizer = new Tokenizer()
.setInputCol("text")
.setOutputCol("words")
val hashingTF = new HashingTF()
.setNumFeatures(1000)
.setInputCol(tokenizer.getOutputCol) // <==== Using the tokenizer output column
.setOutputCol("features")
val lr = new LogisticRegression()
.setMaxIter(10)
.setRegParam(0.001)
val pipeline = new Pipeline()
.setStages(Array(tokenizer, hashingTF, lr))
让我们现在将其应用于考虑另一个假设列的 VectorAssembler alpha
:
val assembler = new VectorAssembler()
.setInputCols(Array("alpha", tokenizer.getOutputCol)
.setOutputCol("features")
我创建了一个自定义矢量汇编器(原始的 1:1 副本),然后将其更改为包括所有列,除了一些要排除的列。
编辑
为了更清楚一点
def setInputColsExcept(value: Array[String]): this.type = set(inputCols, value)
指定应排除哪些列。然后
val remainingColumns = dataset.columns.filter(!$(inputCols).contains(_))
在转换方法中过滤所需的列。
使用 sparks vector assembler 需要预先定义要组装的列。
但是,如果在前面的步骤将修改数据框的列的管道中使用矢量汇编程序,我如何指定列而不手动对所有值进行硬编码?
因为 df.columns
将 不 包含正确的值,当调用矢量汇编程序的构造函数时,目前我没有看到另一种处理方法或拆分管道 - 这也很糟糕,因为 CrossValidator 将不再正常工作。
val vectorAssembler = new VectorAssembler()
.setInputCols(df.columns
.filter(!_.contains("target"))
.filter(!_.contains("idNumber")))
.setOutputCol("features")
编辑
的初始 df---+------+---+-
|foo| id|baz|
+---+------+---+
| 0| 1 | A|
| 1|2 | A|
| 0| 3 | null|
| 1| 4 | C|
+---+------+---+
将进行如下改造。您可以看到 nan 值将被估算为具有最常见的原始列和派生的一些特征,例如如此处所述 isA
如果 baz 为 A,则为 1,否则为 0,如果最初为 null N
+---+------+---+-------+
|foo|id |baz| isA |
+---+------+---+-------+
| 0| 1 | A| 1 |
| 1|2 | A|1 |
| 0| 3 | A| n |
| 1| 4 | C| 0 |
+---+------+---+-------+
稍后在管道中,使用 stringIndexer 使数据适合 ML/vectorAssembler。
isA
在原始 df 中不存在,但在 "only" 输出列中不存在此帧中除 foo 和 id 列之外的所有列都应由矢量汇编器转换。
我希望现在更清楚了。
如果我没看错你的问题,答案会非常简单明了,你只需要使用之前转换器的.getOutputCol
。
示例(来自官方文档):
// Prepare training documents from a list of (id, text, label) tuples.
val training = spark.createDataFrame(Seq(
(0L, "a b c d e spark", 1.0),
(1L, "b d", 0.0),
(2L, "spark f g h", 1.0),
(3L, "hadoop mapreduce", 0.0)
)).toDF("id", "text", "label")
// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
val tokenizer = new Tokenizer()
.setInputCol("text")
.setOutputCol("words")
val hashingTF = new HashingTF()
.setNumFeatures(1000)
.setInputCol(tokenizer.getOutputCol) // <==== Using the tokenizer output column
.setOutputCol("features")
val lr = new LogisticRegression()
.setMaxIter(10)
.setRegParam(0.001)
val pipeline = new Pipeline()
.setStages(Array(tokenizer, hashingTF, lr))
让我们现在将其应用于考虑另一个假设列的 VectorAssembler alpha
:
val assembler = new VectorAssembler()
.setInputCols(Array("alpha", tokenizer.getOutputCol)
.setOutputCol("features")
我创建了一个自定义矢量汇编器(原始的 1:1 副本),然后将其更改为包括所有列,除了一些要排除的列。
编辑
为了更清楚一点
def setInputColsExcept(value: Array[String]): this.type = set(inputCols, value)
指定应排除哪些列。然后
val remainingColumns = dataset.columns.filter(!$(inputCols).contains(_))
在转换方法中过滤所需的列。