Spark-ML 编写自定义模型、转换器
Spark-ML writing custom Model, Transformer
这是 Spark 2.0.1
我正在尝试编译和使用 here 中的 SimpleIndexer
示例。
import org.apache.spark.ml.param._
import org.apache.spark.ml.util._
import org.apache.spark.ml._
import org.apache.spark.sql._
import org.apache.spark.sql.types._
import org.apache.spark.sql.functions._
trait SimpleIndexerParams extends Params {
final val inputCol= new Param[String](this, "inputCol", "The input column")
final val outputCol = new Param[String](this, "outputCol", "The output column")
}
class SimpleIndexer(override val uid: String) extends Estimator[SimpleIndexerModel] with SimpleIndexerParams {
def setInputCol(value: String) = set(inputCol, value)
def setOutputCol(value: String) = set(outputCol, value)
def this() = this(Identifiable.randomUID("simpleindexer"))
override def copy(extra: ParamMap): SimpleIndexer = {
defaultCopy(extra)
}
override def transformSchema(schema: StructType): StructType = {
// Check that the input type is a string
val idx = schema.fieldIndex($(inputCol))
val field = schema.fields(idx)
if (field.dataType != StringType) {
throw new Exception(s"Input type ${field.dataType} did not match input type StringType")
}
// Add the return field
schema.add(StructField($(outputCol), IntegerType, false))
}
override def fit(dataset: Dataset[_]): SimpleIndexerModel = {
import dataset.sparkSession.implicits._
val words = dataset.select(dataset($(inputCol)).as[String]).distinct
.collect()
new SimpleIndexerModel(uid, words)
; }
}
class SimpleIndexerModel(
override val uid: String, words: Array[String]) extends Model[SimpleIndexerModel] with SimpleIndexerParams {
override def copy(extra: ParamMap): SimpleIndexerModel = {
defaultCopy(extra)
}
private val labelToIndex: Map[String, Double] = words.zipWithIndex.
map{case (x, y) => (x, y.toDouble)}.toMap
override def transformSchema(schema: StructType): StructType = {
// Check that the input type is a string
val idx = schema.fieldIndex($(inputCol))
val field = schema.fields(idx)
if (field.dataType != StringType) {
throw new Exception(s"Input type ${field.dataType} did not match input type StringType")
}
// Add the return field
schema.add(StructField($(outputCol), IntegerType, false))
}
override def transform(dataset: Dataset[_]): DataFrame = {
val indexer = udf { label: String => labelToIndex(label) }
dataset.select(col("*"),
indexer(dataset($(inputCol)).cast(StringType)).as($(outputCol)))
}
}
但是,我在转换过程中遇到错误:
val df = Seq(
(10, "hello"),
(20, "World"),
(30, "goodbye"),
(40, "sky")
).toDF("id", "phrase")
val si = new SimpleIndexer().setInputCol("phrase").setOutputCol("phrase_idx").fit(df)
si.transform(df).show(false)
// java.util.NoSuchElementException: Failed to find a default value for inputCol
知道如何解决吗?
SimpleIndexer 转换方法似乎接受数据集作为参数——而不是数据帧(您传入的是数据帧)。
case class Phrase(id: Int, phrase:String)
si.transform(df.as[Phrase])....
有关详细信息,请参阅文档:https://spark.apache.org/docs/2.0.1/sql-programming-guide.html
编辑:
问题似乎是 SimpleIndexerModel 无法通过表达式 $(inputCol)
访问 "phrase" 列。我认为这是因为它在 SimpleIndexer class 中设置(并且上面的表达式工作正常)但在 SimpleIndexerModel.
中无法访问
一种解决方案是手动设置列名:
indexer(dataset.col("phrase").cast(StringType)).as("phrase_idx"))
但是在实例化 SimpleIndexerModel 时传入列名称可能会更好:
class SimpleIndexerModel(override val uid: String, words: Array[String], inputColName: String, outputColName: String)
....
new SimpleIndexerModel(uid, words, $(inputCol), $(outputCol))
结果:
+---+-------+----------+
|id |phrase |phrase_idx|
+---+-------+----------+
|10 |hello |1.0 |
|20 |World |0.0 |
|30 |goodbye|3.0 |
|40 |sky |2.0 |
+---+-------+----------+
好的,我通过查看 CountVectorizer
的源代码找到了答案。看来我需要用 copyValues(new SimpleIndexerModel(uid, words).setParent(this))
替换 new SimpleIndexerModel(uid, words)
。所以新的 fit
方法变成了
override def fit(dataset: Dataset[_]): SimpleIndexerModel = {
import dataset.sparkSession.implicits._
val words = dataset.select(dataset($(inputCol)).as[String]).distinct
.collect()
//new SimpleIndexerModel(uid, words)
copyValues(new SimpleIndexerModel(uid, words).setParent(this))
}
至此,参数被识别,转换很顺利。
val si = new SimpleIndexer().setInputCol("phrase").setOutputCol("phrase_idx").fit(df)
si.explainParams
// res3: String =
// inputCol: The input column (current: phrase)
// outputCol: The output column (current: phrase_idx)
si.transform(df).show(false)
// +---+-------+----------+
// |id |phrase |phrase_idx|
// +---+-------+----------+
// |10 |hello |1.0 |
// |20 |World |0.0 |
// |30 |goodbye|3.0 |
// |40 |sky |2.0 |
// +---+-------+----------+
这是 Spark 2.0.1
我正在尝试编译和使用 here 中的 SimpleIndexer
示例。
import org.apache.spark.ml.param._
import org.apache.spark.ml.util._
import org.apache.spark.ml._
import org.apache.spark.sql._
import org.apache.spark.sql.types._
import org.apache.spark.sql.functions._
trait SimpleIndexerParams extends Params {
final val inputCol= new Param[String](this, "inputCol", "The input column")
final val outputCol = new Param[String](this, "outputCol", "The output column")
}
class SimpleIndexer(override val uid: String) extends Estimator[SimpleIndexerModel] with SimpleIndexerParams {
def setInputCol(value: String) = set(inputCol, value)
def setOutputCol(value: String) = set(outputCol, value)
def this() = this(Identifiable.randomUID("simpleindexer"))
override def copy(extra: ParamMap): SimpleIndexer = {
defaultCopy(extra)
}
override def transformSchema(schema: StructType): StructType = {
// Check that the input type is a string
val idx = schema.fieldIndex($(inputCol))
val field = schema.fields(idx)
if (field.dataType != StringType) {
throw new Exception(s"Input type ${field.dataType} did not match input type StringType")
}
// Add the return field
schema.add(StructField($(outputCol), IntegerType, false))
}
override def fit(dataset: Dataset[_]): SimpleIndexerModel = {
import dataset.sparkSession.implicits._
val words = dataset.select(dataset($(inputCol)).as[String]).distinct
.collect()
new SimpleIndexerModel(uid, words)
; }
}
class SimpleIndexerModel(
override val uid: String, words: Array[String]) extends Model[SimpleIndexerModel] with SimpleIndexerParams {
override def copy(extra: ParamMap): SimpleIndexerModel = {
defaultCopy(extra)
}
private val labelToIndex: Map[String, Double] = words.zipWithIndex.
map{case (x, y) => (x, y.toDouble)}.toMap
override def transformSchema(schema: StructType): StructType = {
// Check that the input type is a string
val idx = schema.fieldIndex($(inputCol))
val field = schema.fields(idx)
if (field.dataType != StringType) {
throw new Exception(s"Input type ${field.dataType} did not match input type StringType")
}
// Add the return field
schema.add(StructField($(outputCol), IntegerType, false))
}
override def transform(dataset: Dataset[_]): DataFrame = {
val indexer = udf { label: String => labelToIndex(label) }
dataset.select(col("*"),
indexer(dataset($(inputCol)).cast(StringType)).as($(outputCol)))
}
}
但是,我在转换过程中遇到错误:
val df = Seq(
(10, "hello"),
(20, "World"),
(30, "goodbye"),
(40, "sky")
).toDF("id", "phrase")
val si = new SimpleIndexer().setInputCol("phrase").setOutputCol("phrase_idx").fit(df)
si.transform(df).show(false)
// java.util.NoSuchElementException: Failed to find a default value for inputCol
知道如何解决吗?
SimpleIndexer 转换方法似乎接受数据集作为参数——而不是数据帧(您传入的是数据帧)。
case class Phrase(id: Int, phrase:String)
si.transform(df.as[Phrase])....
有关详细信息,请参阅文档:https://spark.apache.org/docs/2.0.1/sql-programming-guide.html
编辑:
问题似乎是 SimpleIndexerModel 无法通过表达式 $(inputCol)
访问 "phrase" 列。我认为这是因为它在 SimpleIndexer class 中设置(并且上面的表达式工作正常)但在 SimpleIndexerModel.
一种解决方案是手动设置列名:
indexer(dataset.col("phrase").cast(StringType)).as("phrase_idx"))
但是在实例化 SimpleIndexerModel 时传入列名称可能会更好:
class SimpleIndexerModel(override val uid: String, words: Array[String], inputColName: String, outputColName: String)
....
new SimpleIndexerModel(uid, words, $(inputCol), $(outputCol))
结果:
+---+-------+----------+
|id |phrase |phrase_idx|
+---+-------+----------+
|10 |hello |1.0 |
|20 |World |0.0 |
|30 |goodbye|3.0 |
|40 |sky |2.0 |
+---+-------+----------+
好的,我通过查看 CountVectorizer
的源代码找到了答案。看来我需要用 copyValues(new SimpleIndexerModel(uid, words).setParent(this))
替换 new SimpleIndexerModel(uid, words)
。所以新的 fit
方法变成了
override def fit(dataset: Dataset[_]): SimpleIndexerModel = {
import dataset.sparkSession.implicits._
val words = dataset.select(dataset($(inputCol)).as[String]).distinct
.collect()
//new SimpleIndexerModel(uid, words)
copyValues(new SimpleIndexerModel(uid, words).setParent(this))
}
至此,参数被识别,转换很顺利。
val si = new SimpleIndexer().setInputCol("phrase").setOutputCol("phrase_idx").fit(df)
si.explainParams
// res3: String =
// inputCol: The input column (current: phrase)
// outputCol: The output column (current: phrase_idx)
si.transform(df).show(false)
// +---+-------+----------+
// |id |phrase |phrase_idx|
// +---+-------+----------+
// |10 |hello |1.0 |
// |20 |World |0.0 |
// |30 |goodbye|3.0 |
// |40 |sky |2.0 |
// +---+-------+----------+