Spark 2逻辑回归删除阈值
Spark 2 logisticregression remove threshold
我正在使用 Spark 2 + Scala 训练基于 LogisticRegression 的二元分类模型,我正在使用 import org.apache.spark.ml.classification.LogisticRegression
,这是 Spark 2 中的新 ml API。但是,当我评估时AUROC 的模型,我没有找到使用概率(0-1 中的两倍)而不是二进制分类(0/1)的方法。这个以前是removeThreshold()
实现的,但是在ml.LogisticRegression
我没有找到类似的方法。那么,有没有办法做到这一点?
我使用的评估器是
val evaluator = new BinaryClassificationEvaluator()
.setLabelCol("label")
.setRawPredictionCol("rawPrediction")
.setMetricName("areaUnderROC")
val auroc = evaluator.evaluate(predictions)`
如果你想获得概率输出而不是 0/1 输出,试试这个:
import org.apache.spark.ml.classification.{BinaryLogisticRegressionSummary, LogisticRegression}
val lr = new LogisticRegression()
.setMaxIter(100)
.setRegParam(0.3)
val lrModel = lr.fit(trainData)
val summary = lrModel.summary
summary.predictions.select("probability").show()
import org.apache.spark.ml.classification.{BinaryLogisticRegressionSummary,
LogisticRegression}
val lr = new LogisticRegression().setMaxIter(100).setRegParam(0.3)
val lrModel = lr.fit(trainData)
val trainingSummary = lrModel.summary
val predictions = lrModel.transform(test)
predictions.select("label", "probability").show()
我正在使用 Spark 2 + Scala 训练基于 LogisticRegression 的二元分类模型,我正在使用 import org.apache.spark.ml.classification.LogisticRegression
,这是 Spark 2 中的新 ml API。但是,当我评估时AUROC 的模型,我没有找到使用概率(0-1 中的两倍)而不是二进制分类(0/1)的方法。这个以前是removeThreshold()
实现的,但是在ml.LogisticRegression
我没有找到类似的方法。那么,有没有办法做到这一点?
我使用的评估器是
val evaluator = new BinaryClassificationEvaluator()
.setLabelCol("label")
.setRawPredictionCol("rawPrediction")
.setMetricName("areaUnderROC")
val auroc = evaluator.evaluate(predictions)`
如果你想获得概率输出而不是 0/1 输出,试试这个:
import org.apache.spark.ml.classification.{BinaryLogisticRegressionSummary, LogisticRegression}
val lr = new LogisticRegression()
.setMaxIter(100)
.setRegParam(0.3)
val lrModel = lr.fit(trainData)
val summary = lrModel.summary
summary.predictions.select("probability").show()
import org.apache.spark.ml.classification.{BinaryLogisticRegressionSummary,
LogisticRegression}
val lr = new LogisticRegression().setMaxIter(100).setRegParam(0.3)
val lrModel = lr.fit(trainData)
val trainingSummary = lrModel.summary
val predictions = lrModel.transform(test)
predictions.select("label", "probability").show()