Reduce two Scala methods that differ only in one object type

I have the following two methods, which use objects from Apache Spark.

  import org.apache.spark.SparkContext
  import org.apache.spark.mllib.classification.SVMModel
  import org.apache.spark.mllib.tree.model.DecisionTreeModel
  import org.apache.spark.mllib.util.MLUtils
  import org.apache.spark.rdd.RDD

  def SVMModelScoring(sc: SparkContext, scoringDataset: String, modelFileName: String): RDD[(Double, Double)] = {
    val model = SVMModel.load(sc, modelFileName)

    // Pair each point's predicted score with its true label.
    val scoreAndLabels =
      MLUtils.loadLibSVMFile(sc, scoringDataset).randomSplit(Array(0.1), seed = 11L)(0).map { point =>
        val score = model.predict(point.features)
        (score, point.label)
      }
    scoreAndLabels
  }

  def DecisionTreeScoring(sc: SparkContext, scoringDataset: String, modelFileName: String): RDD[(Double, Double)] = {
    val model = DecisionTreeModel.load(sc, modelFileName)

    val scoreAndLabels =
      MLUtils.loadLibSVMFile(sc, scoringDataset).randomSplit(Array(0.1), seed = 11L)(0).map { point =>
        val score = model.predict(point.features)
        (score, point.label)
      }
    scoreAndLabels
  }

My previous attempts to merge these functions have resulted in errors around model.predict.
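For example, a merged version along these lines (the useSVM flag is hypothetical, purely to illustrate the problem) fails to compile, because the closest common supertype of the two model classes is roughly Saveable with Serializable, which declares no predict(...):

  def scoring(sc: SparkContext, scoringDataset: String, modelFileName: String, useSVM: Boolean): RDD[(Double, Double)] = {
    // The inferred type of model is the least upper bound of SVMModel and
    // DecisionTreeModel, which does not declare predict(...).
    val model = if (useSVM) SVMModel.load(sc, modelFileName)
                else DecisionTreeModel.load(sc, modelFileName)

    MLUtils.loadLibSVMFile(sc, scoringDataset).randomSplit(Array(0.1), seed = 11L)(0).map { point =>
      (model.predict(point.features), point.label) // error: value predict is not a member
    }
  }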

Is there a way to pass the model in as a weakly typed parameter in Scala?

Disclaimer - I have never used Apache Spark.

From what I can see, the only difference between the two methods is the way model is instantiated. Unfortunately, the two model instances do not actually share a common trait that provides predict(...), but we can still make this work by extracting the part that changes - the scorer:

import org.apache.spark.mllib.linalg.Vector

def scoreWith(sc: SparkContext, scoringDataset: String)(scorer: Vector => Double): RDD[(Double, Double)] = {
  MLUtils.loadLibSVMFile(sc, scoringDataset).randomSplit(Array(0.1), seed = 11L)(0).map { point =>
    val score = scorer(point.features)
    (score, point.label)
  }
}

Now we can recover the previous functionality; the second parameter list lets Scala infer the scorer's Vector => Double type at the call site:

def svmScorer(sc: SparkContext, scoringDataset: String, modelFileName: String) =
  scoreWith(sc, scoringDataset)(SVMModel.load(sc, modelFileName).predict)

def dtScorer(sc: SparkContext, scoringDataset: String, modelFileName: String) =
  scoreWith(sc, scoringDataset)(DecisionTreeModel.load(sc, modelFileName).predict)