Correctly labeling predicted classes when using the Java WEKA library
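I have a program that trains an algorithm on a 2-class classification outcome, then runs and writes out predictions (the probability of each of the 2 classes) for an unlabeled dataset.
Every dataset run through this program will have the same 2 classes as the outcome. With that in mind, I ran the predictions, used some post-hoc statistics to work out which column of the results describes which outcome, and then hard-coded the labels: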
public class runPredictions {
    public static void runPredictions(ArrayList al2) throws IOException, Exception {
        // Retrieve objects
        Instances newTest = (Instances) al2.get(0);
        Classifier clf = (Classifier) al2.get(1);
        // Print status
        System.out.println("Generating predictions...");
        // create copy
        Instances labeled = new Instances(newTest);
        BufferedWriter outFile = new BufferedWriter(new FileWriter("silverbullet_rro_output.csv"));
        StringBuilder builder = new StringBuilder();
        builder.append("Prob_Retain" + "," + "Prob_Attrite" + "\n");
        for (int i = 0; i < labeled.size(); i++) {
            double[] clsLabel = clf.distributionForInstance(newTest.instance(i));
            for (int j = 0; j < 2; j++) {
                builder.append(clsLabel[j] + "");
                if (j < clsLabel.length - 1)
                    builder.append(",");
            }
            builder.append("\n");
        }
        outFile.write(builder.toString()); // save the string representation
        System.out.println("Output file written.");
        System.out.println("Completed successfully!");
        outFile.close();
    }
}
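The problem is that which of the 2 columns describes which of the 2 outcome classes is not fixed. It seems to depend on which class appears first in the training dataset, which is completely arbitrary. So when other datasets are used with this program, the hard-coded labels end up reversed.
So I need a better way to label them, but looking at the documentation for Classifier and distributionForInstance, I don't see anything useful.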
Update:
I figured out how to print it to the screen (thanks to this), but I'm still having trouble writing it to the csv:
for (int i = 0; i < labeled.size(); i++)
{
    // Discrete prediction
    double predictionIndex =
        clf.classifyInstance(newTest.instance(i));
    // Get the predicted class label from the predictionIndex.
    String predictedClassLabel =
        newTest.classAttribute().value((int) predictionIndex);
    // Get the prediction probability distribution.
    double[] predictionDistribution =
        clf.distributionForInstance(newTest.instance(i));
    // Print out the predicted label, and the distribution
    System.out.printf("%5d: predicted=%-10s, distribution=",
        i, predictedClassLabel);
    // Loop over all the prediction labels in the distribution.
    for (int predictionDistributionIndex = 0;
         predictionDistributionIndex < predictionDistribution.length;
         predictionDistributionIndex++)
    {
        // Get this distribution index's class label.
        String predictionDistributionIndexAsClassLabel =
            newTest.classAttribute().value(
                predictionDistributionIndex);
        // Get the probability.
        double predictionProbability =
            predictionDistribution[predictionDistributionIndex];
        System.out.printf("[%10s : %6.3f]",
            predictionDistributionIndexAsClassLabel,
            predictionProbability );
        // Attempt to write to CSV
        builder.append(i+","+predictedClassLabel+","+
            predictionDistributionIndexAsClassLabel+","+predictionProbability);
        //.charAt(0)+','+predictionProbability.charAt(0));
    }
    System.out.printf("\n");
    builder.append("\n");
}
I adapted the code below from this answer and this answer. Basically, you can query the test data for its class attribute and then get the specific value of each possible class.
for (int i = 0; i < labeled.size(); i++)
{
    // Discrete prediction
    double predictionIndex =
        clf.classifyInstance(newTest.instance(i));
    // Get the predicted class label from the predictionIndex.
    String predictedClassLabel =
        newTest.classAttribute().value((int) predictionIndex);
    // Get the prediction probability distribution.
    double[] predictionDistribution =
        clf.distributionForInstance(newTest.instance(i));
    // Print out the predicted label, and the distribution
    System.out.printf("%5d: predicted=%-10s, distribution=",
        i, predictedClassLabel);
    // Start the CSV row with the instance index (written once per row).
    builder.append(i);
    // Loop over all the prediction labels in the distribution.
    for (int predictionDistributionIndex = 0;
         predictionDistributionIndex < predictionDistribution.length;
         predictionDistributionIndex++)
    {
        // Get this distribution index's class label.
        String predictionDistributionIndexAsClassLabel =
            newTest.classAttribute().value(
                predictionDistributionIndex);
        // Get the probability.
        double predictionProbability =
            predictionDistribution[predictionDistributionIndex];
        System.out.printf("[%10s : %6.3f]",
            predictionDistributionIndexAsClassLabel,
            predictionProbability );
        // Write to CSV: a leading comma keeps each label/probability pair
        // separated from the previous field.
        builder.append(","+
            predictionDistributionIndexAsClassLabel+","+predictionProbability);
    }
    System.out.printf("\n");
    builder.append("\n");
}
// Save results in .csv file
outFile.write(builder.toString()); // save the string representation
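If you would rather keep the original one-column-per-class layout, one option is to build the header from the class labels instead of hard-coding it. The following is only a minimal sketch: the class name PredictionCsv and the method toCsv are hypothetical, the labels and probabilities in main are made-up illustration values, and plain arrays stand in for the WEKA objects so the snippet runs on its own. In the real program, classLabels[j] would presumably come from newTest.classAttribute().value(j) and each row from clf.distributionForInstance(newTest.instance(i)).

public class PredictionCsv {
    // Builds CSV text whose header is derived from the class labels, so the
    // column order always follows the order of the probability arrays.
    // Assumption: classLabels[j] names the class whose probability is row[j].
    static String toCsv(String[] classLabels, double[][] distributions) {
        StringBuilder builder = new StringBuilder();
        // Header: one "Prob_<label>" column per class, in attribute order.
        for (int j = 0; j < classLabels.length; j++) {
            if (j > 0) builder.append(',');
            builder.append("Prob_").append(classLabels[j]);
        }
        builder.append('\n');
        // Body: one row per instance, probabilities in the same order.
        for (double[] row : distributions) {
            for (int j = 0; j < row.length; j++) {
                if (j > 0) builder.append(',');
                builder.append(row[j]);
            }
            builder.append('\n');
        }
        return builder.toString();
    }

    public static void main(String[] args) {
        // Hypothetical labels and probabilities, for illustration only.
        String[] labels = {"Retain", "Attrite"};
        double[][] probs = {{0.7, 0.3}, {0.2, 0.8}};
        System.out.print(toCsv(labels, probs));
    }
}

Because the header and the rows are generated from the same index order, swapping the order of the labels in a different dataset swaps the header along with the columns, so the labels should no longer end up reversed.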