Weka: how to predict a new unseen instance using Java code?
I wrote WEKA Java code to train 4 classifiers. I saved the classifier models and want to use them to predict new unseen instances (think of someone who wants to test whether a tweet is positive or negative).
I used the StringToWordVector filter on the training data. To avoid the "Src and Dest differ in # of attributes" error, I use the following code to initialize the filter with the training data before applying it to the new instance, trying to predict whether the new instance is positive or negative. I'm not sure I have done this correctly.
Classifier cls = (Classifier) weka.core.SerializationHelper.read("models/myModel.model"); //reading one of the trained classifiers
BufferedReader datafile = readDataFile("Tweets/tone1.ARFF"); //read training data
Instances data = new Instances(datafile);
data.setClassIndex(data.numAttributes() - 1);
Filter filter = new StringToWordVector(50);//keep 50 words
filter.setInputFormat(data);
Instances filteredData = Filter.useFilter(data, filter);
// rebuild classifier
cls.buildClassifier(filteredData);
String testInstance= "Text that I want to use as an unseen instance and predict whether it's positive or negative";
System.out.println(">create test instance");
FastVector attributes = new FastVector(2);
attributes.addElement(new Attribute("text", (FastVector) null));
// Add class attribute.
FastVector classValues = new FastVector(2);
classValues.addElement("Negative");
classValues.addElement("Positive");
attributes.addElement(new Attribute("Tone", classValues));
// Create dataset with initial capacity of 100, and set index of class.
Instances tests = new Instances("test instance", attributes, 100);
tests.setClassIndex(tests.numAttributes() - 1);
Instance test = new Instance(2);
// Set value for message attribute
Attribute messageAtt = tests.attribute("text");
test.setValue(messageAtt, messageAtt.addStringValue(testInstance));
test.setDataset(tests);
Filter filter2 = new StringToWordVector(50);
filter2.setInputFormat(tests);
Instances filteredTests = Filter.useFilter(tests, filter2);
System.out.println(">train Test filter using training data");
Standardize sfilter = new Standardize(); //Match the number of attributes between src and dest.
sfilter.setInputFormat(filteredData); // initializing the filter with training set
filteredTests = Filter.useFilter(filteredData, sfilter); // create new test set
ArffSaver saver = new ArffSaver(); //save test data to ARFF file
saver.setInstances(filteredTests);
File unseenFile = new File ("Tweets/unseen.ARFF");
saver.setFile(unseenFile);
saver.writeBatch();
When I try to standardize the input data using the filtered training data, I do get a new ARFF file (unseen.ARFF), but it contains 2000 instances (the same number as the training data), most of which have negative values. I don't understand why, or how to remove those instances.
System.out.println(">Evaluation"); //without the following 2 lines I get ArrayIndexOutOfBoundException.
filteredData.setClassIndex(filteredData.numAttributes() - 1);
filteredTests.setClassIndex(filteredTests.numAttributes() - 1);
Evaluation eval = new Evaluation(filteredData);
eval.evaluateModel(cls, filteredTests);
System.out.println(eval.toSummaryString("\nResults\n======\n", false));
When printing the evaluation results I want to see, for example, the probability that this instance is positive or negative, but instead I get the output below. I would also expect to see 1 instance rather than 2000. Any help on how to do this would be great.
> Results
======
Correlation coefficient 0.0285
Mean absolute error 0.8765
Root mean squared error 1.2185
Relative absolute error 409.4123 %
Root relative squared error 121.8754 %
Total Number of Instances 2000
Thanks.
Use eval.predictions(). It is a java.util.ArrayList<Prediction>. You can then use the Prediction.weight() method to get how much positive or negative your test instance is.
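A minimal sketch of what that iteration might look like, assuming a newer Weka release where Evaluation.predictions() returns ArrayList<Prediction> (as stated above) and a nominal {Negative, Positive} class as in the question; the helper name printPredictions is just for illustration:
import java.util.ArrayList;
import weka.classifiers.Evaluation;
import weka.classifiers.evaluation.NominalPrediction;
import weka.classifiers.evaluation.Prediction;

// Print the per-class probabilities collected during eval.evaluateModel(...)
static void printPredictions(Evaluation eval) {
    ArrayList<Prediction> predictions = eval.predictions();
    for (Prediction p : predictions) {
        System.out.println("actual=" + p.actual() + " predicted=" + p.predicted()
                + " weight=" + p.weight());
        if (p instanceof NominalPrediction) {
            // distribution() holds one probability per class value, e.g. {Negative, Positive}
            double[] dist = ((NominalPrediction) p).distribution();
            System.out.println("P(Negative)=" + dist[0] + ", P(Positive)=" + dist[1]);
        }
    }
}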
cls.distributionForInstance(newInst) returns the probability distribution for the instance. Check the docs.
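A minimal sketch of that call, assuming cls is the trained classifier from the question and the hypothetical newInst is a single test instance that has already been filtered into the same attribute space (including the class attribute) as the training data:
// newInst must belong to a dataset with the same attributes as the training set.
double[] dist = cls.distributionForInstance(newInst);
// For a {Negative, Positive} class attribute, dist[0] and dist[1] are the class probabilities.
System.out.println("P(Negative)=" + dist[0] + ", P(Positive)=" + dist[1]);
// classifyInstance() returns the index of the most likely class value.
double label = cls.classifyInstance(newInst);
System.out.println("Predicted class index: " + label);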
I have found a good solution and I share my code with you here. It trains a classifier using WEKA Java code and then uses it to predict new unseen instances. Some parts (like the paths) are hard-coded, but you can easily modify the method to take parameters.
/**
 * This method performs classification of unseen instances.
 * It starts by training a model using a selection of classifiers, then classifies new unlabeled instances.
 */
public static void predict() throws Exception {
    // Provide the paths for your training and testing ARFF files; make sure both files
    // have the same structure and exactly the same classes in the header.
    // Initialise the classifier.
    Classifier classifier = null;

    System.out.println("read training arff");
    Instances train = new Instances(new BufferedReader(new FileReader("Train.arff")));
    train.setClassIndex(0); // in my case the class was the first attribute, thus 0; otherwise it is numAttributes() - 1

    System.out.println("read testing arff");
    Instances unlabeled = new Instances(new BufferedReader(new FileReader("Test.arff")));
    unlabeled.setClassIndex(0);

    // Training using a collection of classifiers (NaiveBayes, SMO (a.k.a. SVM), kNN and decision trees).
    String[] algorithms = {"nb", "smo", "knn", "j48"};
    for (int w = 0; w < algorithms.length; w++) {
        if (algorithms[w].equals("nb"))
            classifier = new NaiveBayes();
        if (algorithms[w].equals("smo"))
            classifier = new SMO();
        if (algorithms[w].equals("knn"))
            classifier = new IBk();
        if (algorithms[w].equals("j48"))
            classifier = new J48();

        System.out.println("==========================================================================");
        System.out.println("training using " + algorithms[w] + " classifier");

        Evaluation eval = new Evaluation(train);
        // Perform 10-fold cross-validation.
        eval.crossValidateModel(classifier, train, 10, new Random(1));
        String output = eval.toSummaryString();
        System.out.println(output);
        String classDetails = eval.toClassDetailsString();
        System.out.println(classDetails);

        classifier.buildClassifier(train);
    }
    // Note: after the loop, 'classifier' holds the last classifier built (J48); that is the one used below.

    Instances labeled = new Instances(unlabeled);
    // Label instances (use the trained classifier to classify new unseen instances).
    for (int i = 0; i < unlabeled.numInstances(); i++) {
        double clsLabel = classifier.classifyInstance(unlabeled.instance(i));
        labeled.instance(i).setClassValue(clsLabel);
        System.out.println(clsLabel + " -> " + unlabeled.classAttribute().value((int) clsLabel));
    }

    // Save the model for future use.
    ObjectOutputStream out = new ObjectOutputStream(new FileOutputStream("myModel.dat"));
    out.writeObject(classifier);
    out.close();
    System.out.println("===== Saved model =====");
}
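As a follow-up to the save step above, a minimal sketch of reading the model back for later predictions; the file name myModel.dat and the use of ObjectInputStream simply mirror the save code (weka.core.SerializationHelper.read, as in the question, would work equally well):
import java.io.FileInputStream;
import java.io.ObjectInputStream;
import weka.classifiers.Classifier;

// Load the previously saved classifier so new instances can be classified without retraining.
ObjectInputStream in = new ObjectInputStream(new FileInputStream("myModel.dat"));
Classifier loaded = (Classifier) in.readObject();
in.close();
// 'loaded' can now be used with classifyInstance()/distributionForInstance() on new data
// that has the same attribute structure as the training set.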