Generate ROC curve in Eclipse (Weka)
Asked: 2017-08-10 14:35:20

I use the following code to train and test some classifiers:
```java
import java.awt.BorderLayout;
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.Random;

import javax.swing.JFrame;

import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.bayes.NaiveBayes;
import weka.classifiers.evaluation.NominalPrediction;
import weka.classifiers.evaluation.ThresholdCurve;
import weka.classifiers.trees.RandomForest;
import weka.core.FastVector;
import weka.core.Instances;
import weka.core.Utils;
import weka.gui.visualize.PlotData2D;
import weka.gui.visualize.ThresholdVisualizePanel;

public class WekaTest {

    public static BufferedReader readDataFile(String filename) throws IOException {
        // Fail loudly on a missing file instead of returning null,
        // which would only surface later as a NullPointerException.
        return new BufferedReader(new FileReader(filename));
    }

    public static Evaluation classify(Classifier model,
            Instances trainingSet, Instances testingSet) throws Exception {
        Evaluation evaluation = new Evaluation(trainingSet);
        model.buildClassifier(trainingSet);
        evaluation.evaluateModel(model, testingSet);
        return evaluation;
    }

    public static double calculateAccuracy(FastVector predictions) {
        double correct = 0;
        for (int i = 0; i < predictions.size(); i++) {
            NominalPrediction np = (NominalPrediction) predictions.elementAt(i);
            if (np.predicted() == np.actual()) {
                correct++;
            }
        }
        return 100 * correct / predictions.size();
    }

    public static Instances[][] crossValidationSplit(Instances data, int numberOfFolds) {
        Instances[][] split = new Instances[2][numberOfFolds];
        Random random = new Random(1); // fixed seed so the folds are reproducible
        // testCV() copies a contiguous block of rows, so shuffle and stratify a copy
        // first; otherwise a file sorted by class can yield single-class test folds.
        Instances shuffled = new Instances(data);
        shuffled.randomize(random);
        shuffled.stratify(numberOfFolds);
        for (int i = 0; i < numberOfFolds; i++) {
            split[0][i] = shuffled.trainCV(numberOfFolds, i, random);
            split[1][i] = shuffled.testCV(numberOfFolds, i);
        }
        return split;
    }

    public static void main(String[] args) throws Exception {
        Instances data;
        try (BufferedReader datafile = readDataFile("training_1.arff")) {
            data = new Instances(datafile);
        }
        // The last attribute (the "yes"/"no" label) is the class
        data.setClassIndex(data.numAttributes() - 1);

        // Do 10-fold cross-validation
        Instances[][] split = crossValidationSplit(data, 10);

        // Separate split into training and testing arrays
        Instances[] trainingSplits = split[0];
        Instances[] testingSplits = split[1];

        // Use a set of classifiers
        Classifier[] models = {
            // new J48(),           // a decision tree
            // new PART(),
            // new DecisionTable(), // decision table majority classifier
            // new DecisionStump(), // one-level decision tree
            new NaiveBayes(),
            // new AdaBoostM1(),
            new RandomForest(),
            // new LMT()
        };

        // Run for each model
        for (int j = 0; j < models.length; j++) {
            // Collect every group of predictions for the current model
            FastVector predictions = new FastVector();

            // For each training-testing split pair, train and test the classifier
            for (int i = 0; i < trainingSplits.length; i++) {
                Evaluation validation = classify(models[j], trainingSplits[i], testingSplits[i]);
                predictions.appendElements(validation.predictions());
                System.out.println(validation.toMatrixString());
                // Uncomment to see the summary for each training-testing pair.
                // System.out.println(models[j].toString());

                // Generate the ROC curve for this fold (class index 0 = first declared label)
                ThresholdCurve tc = new ThresholdCurve();
                int classIndex = 0;
                Instances result = tc.getCurve(validation.predictions(), classIndex);
                System.out.println("TPR: " + validation.truePositiveRate(classIndex));
                System.out.println("FNR: " + validation.falseNegativeRate(classIndex));

                // Plot the curve
                ThresholdVisualizePanel vmc = new ThresholdVisualizePanel();
                vmc.setROCString("(Area under ROC = "
                        + Utils.doubleToString(ThresholdCurve.getROCArea(result), 4) + ")");
                vmc.setName(result.relationName());
                PlotData2D tempd = new PlotData2D(result);
                tempd.setPlotName(result.relationName());
                tempd.addInstanceNumberAttribute();
                // Connect every point to its predecessor (the first point has none)
                boolean[] cp = new boolean[result.numInstances()];
                for (int n = 1; n < cp.length; n++) {
                    cp[n] = true;
                }
                tempd.setConnectPoints(cp);
                vmc.addPlot(tempd);

                // Display the curve in its own window
                JFrame jf = new JFrame("Weka Classifier Visualize: " + vmc.getName());
                jf.setSize(500, 400);
                jf.getContentPane().setLayout(new BorderLayout());
                jf.getContentPane().add(vmc, BorderLayout.CENTER);
                jf.setDefaultCloseOperation(JFrame.DISPOSE_ON_CLOSE);
                jf.setVisible(true);
            }

            // Calculate overall accuracy of the current classifier on all splits
            double accuracy = calculateAccuracy(predictions);
            // Print the current classifier's name and accuracy
            System.out.println("Accuracy of " + models[j].getClass().getSimpleName() + ": "
                    + String.format("%.2f%%", accuracy)
                    + "\n---------------------------------");
        }
    }
}
```
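Note that Weka's `Instances.testCV` returns a contiguous block of rows and only the training side is shuffled by `trainCV`. So unless the data is randomized and stratified before splitting, a file whose rows are grouped by class (an assumption; the ARFF file is not shown) can produce test folds that contain only one label. Weka provides `Instances.randomize` and `Instances.stratify` for this. The sketch below is not Weka's implementation. It is a minimal plain-Java illustration of stratified fold assignment, using a hypothetical class `StratifiedFolds`: shuffle the rows of each class, then deal them round-robin across the folds so every fold gets a similar share of each label.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;

// Hypothetical helper, not part of Weka: assigns each row to a fold so that
// every fold receives roughly the same share of each class label.
final class StratifiedFolds {

    /** Returns fold[i] in [0, numFolds) for row i, given the row labels. */
    static int[] assignFolds(String[] labels, int numFolds, long seed) {
        // Group row indices by label, keeping first-seen label order.
        Map<String, List<Integer>> byLabel = new LinkedHashMap<>();
        for (int i = 0; i < labels.length; i++) {
            byLabel.computeIfAbsent(labels[i], key -> new ArrayList<>()).add(i);
        }
        int[] fold = new int[labels.length];
        Random random = new Random(seed);
        int next = 0; // round-robin position, carried over between classes
        for (List<Integer> rows : byLabel.values()) {
            Collections.shuffle(rows, random); // randomize order within a class
            for (int row : rows) {
                fold[row] = next;
                next = (next + 1) % numFolds;
            }
        }
        return fold;
    }

    public static void main(String[] args) {
        // 10 "yes" rows followed by 10 "no" rows, i.e. a file sorted by class.
        String[] labels = new String[20];
        for (int i = 0; i < labels.length; i++) {
            labels[i] = i < 10 ? "yes" : "no";
        }
        int[] fold = assignFolds(labels, 5, 1L);
        for (int f = 0; f < 5; f++) {
            int yes = 0, no = 0;
            for (int i = 0; i < labels.length; i++) {
                if (fold[i] == f) {
                    if (labels[i].equals("yes")) yes++; else no++;
                }
            }
            System.out.println("fold " + f + ": yes=" + yes + " no=" + no);
        }
    }
}
```

With 10 "yes" rows followed by 10 "no" rows and 5 folds, every fold ends up with 2 of each label, regardless of the seed.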
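The ARFF file contains a description, followed by 20 data attributes, followed by a "yes" or "no" class label. Running this on some data computes the TPR and FPR correctly and displays them consistently with each confusion matrix; however, the area under the ROC curve is shown as "NaN", and the curve is a vertical or horizontal straight line.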
What am I doing wrong? Any help would be greatly appreciated.
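A hedged note on the NaN, since the thread never confirms the cause: the area under the ROC curve is normalized by (number of positive instances) × (number of negative instances), so it is undefined (0/0, i.e. NaN) for a set of predictions in which one of the classes does not occur. Weka's `ThresholdCurve.getROCArea` appears to normalize the same way, and each curve in the code above is built from a single test fold, so a fold that contains only "yes" or only "no" instances is one possible explanation. The plain-Java sketch below (the class `PairwiseAuc` is hypothetical, not part of Weka) computes AUC as the fraction of (positive, negative) pairs that are ranked correctly, counting ties as one half:

```java
// Hypothetical helper, not a Weka API: AUC as the probability that a randomly
// chosen positive instance gets a higher score than a randomly chosen negative.
final class PairwiseAuc {

    static double auc(double[] scores, boolean[] isPositive) {
        long positives = 0, negatives = 0;
        for (boolean p : isPositive) {
            if (p) positives++; else negatives++;
        }
        double correctPairs = 0;
        for (int i = 0; i < scores.length; i++) {
            if (!isPositive[i]) continue;
            for (int k = 0; k < scores.length; k++) {
                if (isPositive[k]) continue;
                if (scores[i] > scores[k]) {
                    correctPairs += 1.0;
                } else if (scores[i] == scores[k]) {
                    correctPairs += 0.5; // ties count one half
                }
            }
        }
        // With no positives or no negatives this is 0.0 / 0.0, which is NaN.
        return correctPairs / ((double) positives * negatives);
    }

    public static void main(String[] args) {
        double[] scores = {0.9, 0.8, 0.4, 0.3};
        System.out.println(auc(scores, new boolean[] {true, false, true, false})); // 0.75
        System.out.println(auc(scores, new boolean[] {true, true, true, true}));   // NaN
    }
}
```

The first call mixes both classes and gives 0.75 (three of the four positive/negative pairs are ordered correctly); the second has no negative instance and yields NaN.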
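Answer 1: This would be a comment, but I'm new and can't comment yet. I ran your code from inside the loop and, with my data, it works like a charm, so the plotting is not the problem.

It looks like your evaluation works, since you wrote:

"Running this on some data computes the TPR and FPR correctly and displays them consistently with each confusion matrix;"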
So have you tried the following method of the Evaluation class?

```java
evaluation.areaUnderROC(classIndex);
```

That would show you what the area under the ROC curve should be.
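`areaUnderROC` returns one number per class. To cross-check it against the points returned by `getCurve` (for example by reading the curve's false-positive-rate and true-positive-rate columns with `Instances.attributeToDoubleArray`), the area under the piecewise-linear curve can be computed with the trapezoidal rule. A plain-Java sketch; the class `TrapezoidAuc` is hypothetical, and it assumes the rates have already been copied into two arrays of equal length:

```java
import java.util.Arrays;

// Hypothetical helper: trapezoidal area under a piecewise-linear ROC curve
// given its (FPR, TPR) points in any order.
final class TrapezoidAuc {

    static double area(double[] fpr, double[] tpr) {
        if (fpr.length != tpr.length) {
            throw new IllegalArgumentException("fpr and tpr must have equal length");
        }
        // Sort point indices by FPR, then TPR, so the curve runs left to right.
        Integer[] order = new Integer[fpr.length];
        for (int i = 0; i < order.length; i++) {
            order[i] = i;
        }
        Arrays.sort(order, (a, b) -> {
            int c = Double.compare(fpr[a], fpr[b]);
            return c != 0 ? c : Double.compare(tpr[a], tpr[b]);
        });
        double area = 0;
        for (int n = 1; n < order.length; n++) {
            int prev = order[n - 1], cur = order[n];
            // Width times average height of the trapezoid between two points.
            area += (fpr[cur] - fpr[prev]) * (tpr[cur] + tpr[prev]) / 2.0;
        }
        return area;
    }

    public static void main(String[] args) {
        double[] fpr = {0.0, 0.0, 0.5, 0.5, 1.0};
        double[] tpr = {0.0, 0.5, 0.5, 1.0, 1.0};
        System.out.println(area(fpr, tpr)); // 0.75
    }
}
```

If any rate in the arrays is NaN, the arithmetic propagates it and the returned area is NaN as well.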
Are your class labels "yes"/"no" or 0/1? I don't think that is the problem, but you could try

```java
Instances result = tc.getCurve(validation.predictions());
```

instead of

```java
Instances result = tc.getCurve(validation.predictions(), classIndex);
```
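The `classIndex` argument selects which label is treated as positive when the curve is built; with `@attribute class {yes,no}` in the ARFF header, index 0 is "yes". To make that concrete, the plain-Java sketch below (hypothetical class `RocPoints`; it is not Weka's `ThresholdCurve` and does not reproduce its exact output) sorts instances by their predicted probability for the chosen positive class and emits one (FPR, TPR) point per distinct score, starting at (0, 0):

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical sketch of how ROC points depend on the chosen positive class.
final class RocPoints {

    /**
     * dist[i] is the predicted class distribution of instance i and actual[i]
     * its true class index. Returns {fpr, tpr} pairs starting at (0, 0).
     */
    static List<double[]> curve(double[][] dist, int[] actual, int positiveIndex) {
        int n = actual.length;
        Integer[] order = new Integer[n];
        for (int i = 0; i < n; i++) {
            order[i] = i;
        }
        // Highest probability of the positive class first.
        Arrays.sort(order, (a, b) -> Double.compare(dist[b][positiveIndex], dist[a][positiveIndex]));

        int totalPos = 0;
        for (int c : actual) {
            if (c == positiveIndex) totalPos++;
        }
        int totalNeg = n - totalPos;

        List<double[]> points = new ArrayList<>();
        points.add(new double[] {0.0, 0.0});
        int tp = 0, fp = 0;
        for (int k = 0; k < n; k++) {
            int i = order[k];
            if (actual[i] == positiveIndex) tp++; else fp++;
            // Emit a point only after the last instance sharing this score.
            boolean lastOfTie = k == n - 1
                    || dist[order[k + 1]][positiveIndex] != dist[i][positiveIndex];
            if (lastOfTie) {
                // If a class is missing, its denominator is 0 and the rate is NaN.
                points.add(new double[] {(double) fp / totalNeg, (double) tp / totalPos});
            }
        }
        return points;
    }

    public static void main(String[] args) {
        // Class 0 = "yes", class 1 = "no"; each distribution sums to 1.
        double[][] dist = {{0.9, 0.1}, {0.8, 0.2}, {0.4, 0.6}, {0.3, 0.7}};
        int[] actual = {0, 1, 0, 1};
        for (int positive = 0; positive < 2; positive++) {
            StringBuilder sb = new StringBuilder("positive=" + positive + ":");
            for (double[] p : curve(dist, actual, positive)) {
                sb.append(" (").append(p[0]).append(", ").append(p[1]).append(')');
            }
            System.out.println(sb);
        }
    }
}
```

For a binary problem either index gives a valid curve, as long as the predictions contain both labels; if they contain no negative instance, every FPR is 0.0 / 0 and therefore NaN.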
Could you post some of the confusion matrices and the TPR/FPR values?
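For reference when comparing those numbers: with one class treated as positive, the rates follow directly from the confusion matrix, TPR = TP / (TP + FN) and FPR = FP / (FP + TN). A minimal plain-Java sketch (hypothetical class `Rates`), assuming the matrix is laid out like Weka's `toMatrixString` output, with rows as actual classes and columns as predicted classes; the counts in `main` are made up:

```java
// Hypothetical helper: rates for class `positive` from a confusion matrix whose
// rows are actual classes and whose columns are predicted classes.
final class Rates {

    static double truePositiveRate(double[][] m, int positive) {
        double tp = m[positive][positive];
        double actualPos = 0;
        for (double v : m[positive]) {
            actualPos += v; // TP + FN
        }
        return tp / actualPos;
    }

    static double falsePositiveRate(double[][] m, int positive) {
        double fp = 0, actualNeg = 0;
        for (int r = 0; r < m.length; r++) {
            if (r == positive) continue;
            for (double v : m[r]) {
                actualNeg += v; // FP + TN
            }
            fp += m[r][positive]; // negatives predicted as the positive class
        }
        return fp / actualNeg;
    }

    public static void main(String[] args) {
        // rows: actual yes, actual no; columns: predicted yes, predicted no
        double[][] m = {{40, 10}, {5, 45}};
        System.out.println("TPR(yes) = " + truePositiveRate(m, 0)); // 40 / 50 = 0.8
        System.out.println("FPR(yes) = " + falsePositiveRate(m, 0)); // 5 / 50 = 0.1
    }
}
```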
Cheers