How to predict the k-means cluster with Spark org.apache.spark.ml.clustering.{KMeans, KMeansModel}

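I have a question about the two different MLlib implementations (org.apache.spark.ml and org.apache.spark.mllib) and KMeans. I'm using the newer org.apache.spark.ml implementation, which works on DataFrames, but I'm struggling with the documentation and can't work out how to predict the cluster index.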

import org.apache.spark.ml.clustering.{KMeans, KMeansModel}
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.{Row, SparkSession}

/**
  * An example showcasing the use of kMeans
  */
object ExploreKMeans {

  // Spark configuration.
  // Retrieve sparkContext with spark.sparkContext.
  private val spark = SparkSession.builder()
    .appName("com.example.ml.exploration.kMeans")
    .master("local[*]")
    .getOrCreate()

  // This import, placed after the SparkSession is created, brings in the implicits that convert RDDs to DataFrames via .toDF().
  import spark.implicits._

  def main(args: Array[String]): Unit = {

    val data = spark.sparkContext.parallelize(Array((5.0, 2.0,1.5), (2.0, 2.5,2.3), (1.0, 2.1,4.2), (2.0, 5.5, 8.5)))

    val df = data.toDF().map { row =>
      val label = row(0).asInstanceOf[Double]
      val value1 = row(1).asInstanceOf[Double]
      val value2 = row(2).asInstanceOf[Double]

      LabeledPoint(label, Vectors.dense(value1,value2))
    }


    val kmeans = new KMeans().setK(3).setSeed(1L)
    val model: KMeansModel = kmeans.fit(df)

    // Evaluate clustering by computing Within Set Sum of Squared Errors.
    // Note: computeCost is available in Spark 2.x only; it was deprecated in 2.4 and removed in 3.0.
    val WSSSE = model.computeCost(df)
    println(s"Within Set Sum of Squared Errors = $WSSSE")

    // Shows the result.
    println("Cluster Centers: ")
    model.clusterCenters.foreach(println)

    //TODO How to predict cluster index?
    //model.predict(???
  }
}

How do I use the model to predict the cluster index for new values? The model.predict function is not visible. This API is really confusing...

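OK, I figured it out. In the DataFrame-based API, prediction is done with the transform method, which appends a prediction column to the input:

println("Transform: ")
val transformed = model.transform(df)
transformed.collect().foreach(println)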

Cluster Centers: 
[2.25,1.9]
[5.5,8.5]
[2.1,4.2]
Transform: 
[5.0,[2.0,1.5],0]
[2.0,[2.5,2.3],0]
[1.0,[2.1,4.2],2]
[2.0,[5.5,8.5],1]
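
The last value in each transformed row is the index of the closest cluster center. As a rough illustration of what that column means, here is a plain-Scala sketch (no Spark needed) that assigns each point to its nearest center by squared Euclidean distance. The object name NearestCenter and its helper functions are made up for this example, the centers are copied from the output above, and it assumes the default Euclidean distance measure; it is not how Spark implements prediction internally.

```scala
object NearestCenter {

  // Cluster centers copied from the "Cluster Centers" output above.
  val centers: Array[Array[Double]] = Array(
    Array(2.25, 1.9),
    Array(5.5, 8.5),
    Array(2.1, 4.2)
  )

  // Squared Euclidean distance between two points of equal dimension.
  def squaredDistance(a: Array[Double], b: Array[Double]): Double =
    a.zip(b).map { case (x, y) => (x - y) * (x - y) }.sum

  // Index of the center closest to the given point (hypothetical helper).
  def predict(point: Array[Double]): Int =
    centers.indices.minBy(i => squaredDistance(centers(i), point))

  def main(args: Array[String]): Unit = {
    // The four feature vectors from the question.
    val points = Seq(
      Array(2.0, 1.5),
      Array(2.5, 2.3),
      Array(2.1, 4.2),
      Array(5.5, 8.5)
    )
    points.foreach { p =>
      println(s"${p.mkString("[", ",", "]")} -> ${predict(p)}")
    }
  }
}
```

For the four points in the question this gives indices 0, 0, 2 and 1, the same assignments as the transform output.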

Well, an even simpler way (for the data the model was trained on) is:

model.summary.predictions.show
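
Note that summary.predictions only covers the training data. For new values, which is what the question originally asked about, transform can be applied to any DataFrame that has a vector column with the expected name ("features" by default). A minimal sketch, assuming Spark 2.x or later on the classpath; the object name PredictNewPoints, the run method and the two new points are made up for illustration:

```scala
import org.apache.spark.ml.clustering.KMeans
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

object PredictNewPoints {

  // Fits k-means on the question's feature vectors, then predicts
  // cluster indices for two unseen points and returns them.
  def run(): Array[Int] = {
    val spark = SparkSession.builder()
      .appName("PredictNewPoints")
      .master("local[*]")
      .getOrCreate()

    // Training data: the same feature vectors as in the question,
    // stored in the default "features" column.
    val training = spark.createDataFrame(Seq(
      Vectors.dense(2.0, 1.5),
      Vectors.dense(2.5, 2.3),
      Vectors.dense(2.1, 4.2),
      Vectors.dense(5.5, 8.5)
    ).map(Tuple1.apply)).toDF("features")

    val model = new KMeans().setK(3).setSeed(1L).fit(training)

    // New points (hypothetical values) with the same column name.
    val newPoints = spark.createDataFrame(Seq(
      Vectors.dense(2.2, 2.0),
      Vectors.dense(5.0, 8.0)
    ).map(Tuple1.apply)).toDF("features")

    // transform appends an integer "prediction" column.
    val predicted = model.transform(newPoints)
    predicted.show()

    predicted.select("prediction").collect().map(_.getInt(0))
  }

  def main(args: Array[String]): Unit =
    println(run().mkString(", "))
}
```

The cluster numbering depends on initialization, so the indices printed here need not line up with the ones shown earlier.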