在 Java 中使用随机森林打印实际和预测的类标签

Posted

技术标签:

【中文标题】在 Java 中使用随机森林打印实际和预测的类标签【英文标题】:Print actual and predicted class labels using Random Forest in Java 【发布时间】:2017-11-02 05:02:57 【问题描述】:

我有一个包含 10000 条记录的大型数据集,其中 5000 条属于 1 类,其余 5000 条属于 -1 类。我使用了随机森林,获得了超过 90% 的良好准确率。

如果我有一个 arff 文件

@relation cds_orf
@attribute start numeric
@attribute end numeric
@attribute score numeric
@attribute orf_coverage numeric
@attribute class 1,-1
@data
(suppose this contains 5 records)

我的输出应该是这样的

 No   Actual_class   Predicted class
 1     1                   1
 2     1                   1   
 3    -1                  -1  
 4     1                   -1
 5     1                    1

我希望 Java 代码打印此输出。谢谢。 (注意:我使用了 classifier.classifyInstance() 但它给出了 NullPointerException)

【问题讨论】:

【参考方案1】:

好吧,经过大量研究,我自己找到了答案。以下代码执行相同操作并将输出写入另一个文件 orf_out。

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.PrintWriter;
import java.util.Random;
import weka.classifiers.Evaluation;
import weka.classifiers.trees.RandomForest;  
import weka.core.Instances;
 
/**
 *
 * @author samy
 */
public class WekaTest 
 
    /**
     * @throws java.lang.Exception
     */
    public static void rfnew() throws Exception 
        BufferedReader br;
        int numFolds = 10;
        br = new BufferedReader(new FileReader("orf_arff"));
 
        Instances trainData = new Instances(br);
        trainData.setClassIndex(trainData.numAttributes() - 1);
        br.close();
        
        RandomForest rf = new RandomForest();
        rf.setNumTrees(100);         
     
        Evaluation evaluation = new Evaluation(trainData);
        evaluation.crossValidateModel(rf, trainData, numFolds, new Random(1));
        rf.buildClassifier(trainData);
        PrintWriter out = new PrintWriter("orf_out");
        out.println("No.\tTrue\tPredicted");
        for (int i = 0; i < trainData.numInstances(); i++)      
        
            String trueClassLabel;
            trueClassLabel = trainData.instance(i).toString(trainData.classIndex());
             // Discreet prediction
            double predictionIndex = 
            rf.classifyInstance(trainData.instance(i)); 

            // Get the predicted class label from the predictionIndex.
            String predictedClassLabel;            
            predictedClassLabel = trainData.classAttribute().value((int) predictionIndex);
            out.println((i+1)+"\t"+trueClassLabel+"\t"+predictedClassLabel);
        
        
        out.println(evaluation.toSummaryString("\nResults\n======\n", true));
        out.println(evaluation.toClassDetailsString());
        out.println("Results For Class -1- ");
        out.println("Precision=  " + evaluation.precision(0));
        out.println("Recall=  " + evaluation.recall(0));
        out.println("F-measure=  " + evaluation.fMeasure(0));
        out.println("Results For Class -2- ");
        out.println("Precision=  " + evaluation.precision(1));
        out.println("Recall=  " + evaluation.recall(1));
        out.println("F-measure=  " + evaluation.fMeasure(1)); 
        out.close();
    

我需要在我的代码中使用 buildClassifier。

【讨论】:

以上是关于在 Java 中使用随机森林打印实际和预测的类标签的主要内容,如果未能解决你的问题,请参考以下文章

r中随机森林的类重要性

python基于随机森林模型的预测概率和标签信息可视化ROC曲线

如何在python中用随机森林预测多个类?

在 R 中查看随机森林的预测值与实际值

从随机森林模型中提取树的子集进行预测

TensorFlow 中的神经网络比随机森林效果更差,并且每次都预测相同的标签