How to get confidence scores from Spark MLLib Logistic Regression in java

Update: I tried to generate the confidence score using the approach below, but it throws an exception. This is the snippet I use:

// raw margin: w · x (this assumes one weight per feature, i.e. a binary model)
double point = BLAS.dot(logisticregressionmodel.weights(), datavector);
// sigmoid of the margin
double confScore = 1.0 / (1.0 + Math.exp(-point));

The exception I get is:

Caused by: java.lang.IllegalArgumentException: requirement failed: BLAS.dot(x: Vector, y:Vector) was given Vectors with non-matching sizes: x.size = 198, y.size = 18
    at scala.Predef$.require(Predef.scala:233)
    at org.apache.spark.mllib.linalg.BLAS$.dot(BLAS.scala:99)
    at org.apache.spark.mllib.linalg.BLAS.dot(BLAS.scala)

Could you please help? The weights vector seems to have more elements (198) than the data vector (I generate 18 features), and dot() requires both to be the same length.
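(For reference: in a multiclass LogisticRegressionWithLBFGS model, weights() is not one coefficient per feature but the flattened coefficients of numClasses - 1 sub-models, each slice optionally ending with an intercept slot, which is presumably why its length exceeds the feature count. A minimal sketch of the size arithmetic, with a hypothetical helper class and method name:)

import org.apache.spark.mllib.classification.LogisticRegressionModel;

public final class WeightLayout {
    // Hypothetical helper: report how a multiclass model's flattened
    // weights vector is sliced per sub-model.
    public static void describe(final LogisticRegressionModel model,
            final int numFeatures) {
        int numClasses = model.numClasses();
        // Each of the (numClasses - 1) sub-models owns one contiguous
        // slice of weights(); a slice may carry an extra intercept slot.
        int sliceSize = model.weights().size() / (numClasses - 1);
        boolean withBias = sliceSize == (numFeatures + 1);
        System.out.printf("weights=%d, slices=%d x %d, intercept=%b%n",
                model.weights().size(), numClasses - 1, sliceSize, withBias);
    }
}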

I am trying to implement a program in Java that trains on an existing dataset and predicts on a new dataset using the logistic regression algorithm available in Spark MLLib (1.5.0). My training and prediction programs are below, and I am using the multiclass implementation. The problem is that when I call model.predict(vector) (note the lrmodel.predict() in the prediction program), I get the predicted label. But what if I need the confidence score? How do I get it? I have gone through the API and could not find any method that returns a confidence score. Can anyone help me?

Training program (generates the .model file)

public static void main(final String[] args) throws Exception {
        JavaSparkContext jsc = null;
        int salesIndex = 1;

        try {
           ...
       SparkConf sparkConf =
                    new SparkConf().setAppName("Hackathon Train").setMaster(
                            sparkMaster);
            jsc = new JavaSparkContext(sparkConf);
            ...

            JavaRDD<String> trainRDD = jsc.textFile(basePath + "old-leads.csv").cache();

            final String firstRdd = trainRDD.first().trim();
            JavaRDD<String> tempRddFilter =
                    trainRDD.filter(new org.apache.spark.api.java.function.Function<String, Boolean>() {
                        private static final long serialVersionUID =
                                11111111111111111L;

                        public Boolean call(final String arg0) {
                            return !arg0.trim().equalsIgnoreCase(firstRdd);
                        }
                    });

           ...
            JavaRDD<String> featureRDD =
                    tempRddFilter
                            .map(new org.apache.spark.api.java.function.Function() {
                                private static final long serialVersionUID =
                                        6948900080648474074L;

                                public Object call(final Object arg0)
                                        throws Exception {
                                   ...
                                    StringBuilder featureSet =
                                            new StringBuilder();
                                   ...
                                        featureSet.append(i - 2);
                                        featureSet.append(COLON);
                                        featureSet.append(strVal);
                                        featureSet.append(SPACE);
                                    }

                                    return featureSet.toString().trim();
                                }
                            });

            List<String> featureList = featureRDD.collect();
            String featureOutput = StringUtils.join(featureList, NEW_LINE);
            String filePath = basePath + "lr.arff";
            FileUtils.writeStringToFile(new File(filePath), featureOutput,
                    "UTF-8");

            JavaRDD<LabeledPoint> trainingData =
                    MLUtils.loadLibSVMFile(jsc.sc(), filePath).toJavaRDD().cache();

            final LogisticRegressionModel model =
                    new LogisticRegressionWithLBFGS().setNumClasses(18).run(
                            trainingData.rdd());
            ByteArrayOutputStream baos = new ByteArrayOutputStream();
            ObjectOutputStream oos = new ObjectOutputStream(baos);
            oos.writeObject(model);
            oos.flush();
            oos.close();
            FileUtils.writeByteArrayToFile(new File(basePath + "lr.model"),
                    baos.toByteArray());
            baos.close();

        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            if (jsc != null) {
                jsc.close();
            }
        }
    }
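As an aside, instead of hand-rolled Java serialization, mllib models can also be persisted with the built-in save/load API (available since Spark 1.3); the "lr-model" path below is illustrative:

// Persist the model in Spark's own format (written as a directory).
model.save(jsc.sc(), basePath + "lr-model");

// ...and in the prediction program, load it back:
LogisticRegressionModel lrmodel =
        LogisticRegressionModel.load(jsc.sc(), basePath + "lr-model");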

Prediction program (uses the lr.model generated by the training program)

    public static void main(final String[] args) throws Exception {
        JavaSparkContext jsc = null;
        int salesIndex = 1;
        try {
            ...
        SparkConf sparkConf =
                    new SparkConf().setAppName("Hackathon Predict").setMaster(sparkMaster);
            jsc = new JavaSparkContext(sparkConf);

            ObjectInputStream objectInputStream =
                    new ObjectInputStream(new FileInputStream(basePath
                            + "lr.model"));
            LogisticRegressionModel lrmodel =
                    (LogisticRegressionModel) objectInputStream.readObject();
            objectInputStream.close();

            ...

            JavaRDD<String> trainRDD = jsc.textFile(basePath + "new-leads.csv").cache();

            final String firstRdd = trainRDD.first().trim();
            JavaRDD<String> tempRddFilter =
                    trainRDD.filter(new org.apache.spark.api.java.function.Function<String, Boolean>() {
                        private static final long serialVersionUID =
                                11111111111111111L;

                        public Boolean call(final String arg0) {
                            return !arg0.trim().equalsIgnoreCase(firstRdd);
                        }
                    });

            ...
            final Broadcast<LogisticRegressionModel> broadcastModel =
                    jsc.broadcast(lrmodel);

            JavaRDD<String> featureRDD =
                    tempRddFilter
                            .map(new org.apache.spark.api.java.function.Function() {
                                private static final long serialVersionUID =
                                        6948900080648474074L;

                                public Object call(final Object arg0)
                                        throws Exception {
                                   ...
                                    LogisticRegressionModel lrModel =
                                            broadcastModel.value();
                                    String row = ((String) arg0);
                                    String[] featureSetArray =
                                            row.split(CSV_SPLITTER);
                                   ...
                                    final Vector vector =
                                            Vectors.dense(doubleArr);
                                    double score = lrModel.predict(vector);
                                   ...
                                    return csvString;
                                }
                            });

            String outputContent =
                    featureRDD
                            .reduce(new org.apache.spark.api.java.function.Function2() {

                                private static final long serialVersionUID =
                                        1212970144641935082L;

                                public Object call(Object arg0, Object arg1)
                                        throws Exception {
                                    ...
                                }

                            });
            ...
            FileUtils.writeStringToFile(new File(basePath
                    + "predicted-sales-data.csv"), sb.toString());
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            if (jsc != null) {
                jsc.close();
            }
        }
    }
}

It really does not seem to be possible. Looking at the source code, you could probably extend it to return these probabilities.

https://github.com/apache/spark/blob/branch-1.5/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala

if (numClasses == 2) {
  val margin = dot(weightMatrix, dataMatrix) + intercept
  val score = 1.0 / (1.0 + math.exp(-margin))
  threshold match {
    case Some(t) => if (score > t) 1.0 else 0.0
    case None => score
  }
}

I hope this helps you get started on a workaround.
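Also worth noting: for a binary model there is a shortcut that needs no source changes. Calling clearThreshold() makes predict() return the raw sigmoid score instead of the 0.0/1.0 label. A minimal sketch (binary models only, so it will not directly help the 18-class case above):

// With the threshold cleared, predict() returns the probability-like
// score in (0, 1) rather than the thresholded class label.
model.clearThreshold();
double score = model.predict(vector);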

After many attempts, I finally managed to write a custom function to generate confidence scores. It is not perfect at all, but it works for me for now!

private static double getConfidenceScore(
            final LogisticRegressionModel lrModel, final Vector vector) {
        /* Approach to get confidence scores starts */
        Vector weights = lrModel.weights();
        int numClasses = lrModel.numClasses();
        // The weights of the (numClasses - 1) sub-models are flattened
        // into a single vector; each sub-model owns one contiguous slice.
        int dataWithBiasSize = weights.size() / (numClasses - 1);
        boolean withBias = (vector.size() + 1) == dataWithBiasSize;
        double[] features = vector.toArray();
        double[] weightArr = weights.toArray();
        double maxMargin = 0.0;
        double margin = 0.0;
        for (int j = 0; j < (numClasses - 1); j++) {
            margin = 0.0;
            for (int k = 0; k < features.length; k++) {
                double value = features[k];
                if (value != 0.0) {
                    margin += value * weightArr[(j * dataWithBiasSize) + k];
                }
            }
            // The intercept, if present, sits after the feature weights
            // in each sub-model's slice.
            if (withBias) {
                margin += weightArr[(j * dataWithBiasSize) + features.length];
            }
            // Track the largest margin across sub-models; it corresponds
            // to the predicted class.
            if (margin > maxMargin) {
                maxMargin = margin;
            }
        }
        // Squash the winning margin through the sigmoid and express it
        // as a percentage rounded to two decimal places.
        double conf = 1.0 / (1.0 + Math.exp(-maxMargin));
        DecimalFormat twoDForm = new DecimalFormat("#.##");
        double confidenceScore = Double.valueOf(twoDForm.format(conf * 100));
        /* Approach to get confidence scores ends */
        return confidenceScore;
    }
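For completeness, the helper can be called from the map function in the prediction program, right where lrModel.predict(vector) is invoked; the 75.0 cutoff below is just an illustrative choice:

double label = lrModel.predict(vector);
double confidence = getConfidenceScore(lrModel, vector);
// Keep only predictions we are reasonably sure about (cutoff is arbitrary).
if (confidence >= 75.0) {
    // emit the predicted label together with its confidence
}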