How to save the best tree after 10-fold cross-validation on J48 / C4.5 using the Weka API

Posted: 2014-11-01 20:59:31

Question:

Hope everyone is keeping well.
I want to run 10-fold cross-validation on a data set, using J48 as the classifier. So the data is loaded, and then I want to create the training and test sets via 10-fold cross-validation. After the 10 trees have been generated, I want to save the tree with the best classification accuracy and use it as a rule base.

So far my program loads the data and performs the 10-fold cross-validation, and then I tried to save the classifier. But I am not sure I am on the right track, because I want to save the best of the 10 trees, along with its rules.

How can I do that?

Also, where do I create the test set? Does my program perform the split correctly?

Thanks for any advice and help.
import java.io.BufferedReader;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.util.Random;

import weka.classifiers.Evaluation;
import weka.classifiers.trees.J48;
import weka.core.Instances;

public class Unpruned {

    public static void main(String[] args) {
        try {
            BufferedReader bReader = readDataFile("weather.arff");
            Instances train = new Instances(bReader);
            train.setClassIndex(train.numAttributes() - 1); // last attribute is the class attribute

            J48 myTree = new J48();
            myTree.setUnpruned(true);

            Evaluation eval = new Evaluation(train);
            // first supply the classifier, then the training data,
            // the number of folds and the random seed
            eval.crossValidateModel(myTree, train, 10, new Random(1));
            System.out.println("Percent correct: " + Double.toString(eval.pctCorrect()));

            // build a final model on all the data and save it
            myTree.buildClassifier(train);
            System.out.print(myTree.graph());
            weka.core.SerializationHelper.write("D:/myTree.model", myTree);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    public static BufferedReader readDataFile(String filename) {
        BufferedReader inputReader = null;
        try {
            inputReader = new BufferedReader(new FileReader(filename));
        } catch (FileNotFoundException ex) {
            System.err.println("File not found: " + filename);
        }
        return inputReader;
    }
}
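For reference, the tree saved above can be read back later and used directly for prediction. A minimal sketch, assuming the same weather.arff file and the model path used above:

import java.io.BufferedReader;
import java.io.FileReader;

import weka.classifiers.trees.J48;
import weka.core.Instances;

public class UseSavedTree {
    public static void main(String[] args) throws Exception {
        // Restore the tree serialized by the program above
        J48 tree = (J48) weka.core.SerializationHelper.read("D:/myTree.model");

        // Classify the first instance of the data set as a smoke test
        Instances data = new Instances(new BufferedReader(new FileReader("weather.arff")));
        data.setClassIndex(data.numAttributes() - 1);
        double label = tree.classifyInstance(data.instance(0));
        System.out.println("Predicted: " + data.classAttribute().value((int) label));
    }
}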
Update:

@Matthew Spencer

OK, I have now changed my program to the one below, and in the output I get 10 trees, each with the rules it produces. How would I save just one of them, for example the tree from fold 9?
import java.io.BufferedReader;
import java.io.FileNotFoundException;
import java.io.FileReader;

import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.evaluation.NominalPrediction;
import weka.classifiers.rules.DecisionTable;
import weka.classifiers.rules.PART;
import weka.classifiers.trees.DecisionStump;
import weka.classifiers.trees.J48;
import weka.core.FastVector;
import weka.core.Instances;

public class WekaTest {

    public static BufferedReader readDataFile(String filename) {
        BufferedReader inputReader = null;
        try {
            inputReader = new BufferedReader(new FileReader(filename));
        } catch (FileNotFoundException ex) {
            System.err.println("File not found: " + filename);
        }
        return inputReader;
    }

    public static Evaluation classify(Classifier model,
            Instances trainingSet, Instances testingSet) throws Exception {
        Evaluation evaluation = new Evaluation(trainingSet);
        model.buildClassifier(trainingSet);
        evaluation.evaluateModel(model, testingSet);
        return evaluation;
    }

    public static double calculateAccuracy(FastVector predictions) {
        double correct = 0;
        for (int i = 0; i < predictions.size(); i++) {
            NominalPrediction np = (NominalPrediction) predictions.elementAt(i);
            if (np.predicted() == np.actual()) {
                correct++;
            }
        }
        return 100 * correct / predictions.size();
    }

    public static Instances[][] crossValidationSplit(Instances data, int numberOfFolds) {
        Instances[][] split = new Instances[2][numberOfFolds];
        for (int i = 0; i < numberOfFolds; i++) {
            split[0][i] = data.trainCV(numberOfFolds, i);
            split[1][i] = data.testCV(numberOfFolds, i);
        }
        return split;
    }

    public static void main(String[] args) throws Exception {
        BufferedReader datafile = readDataFile("pima_diabetes.arff");
        Instances data = new Instances(datafile);
        data.setClassIndex(data.numAttributes() - 1);

        // Do 10-fold cross-validation
        Instances[][] split = crossValidationSplit(data, 10);

        // Separate the split into training and testing arrays
        Instances[] trainingSplits = split[0];
        Instances[] testingSplits = split[1];

        // Use a set of classifiers
        Classifier[] models = {
                new J48(),           // a decision tree
                new PART(),          // a rule learner
                new DecisionTable(), // decision table majority classifier
                new DecisionStump()  // one-level decision tree
        };

        // Run for each model
        for (int j = 0; j < models.length; j++) {
            // Collect every group of predictions for the current model in a FastVector
            FastVector predictions = new FastVector();

            // For each training-testing split pair, train and test the classifier
            for (int i = 0; i < trainingSplits.length; i++) {
                Evaluation validation = classify(models[j], trainingSplits[i], testingSplits[i]);
                predictions.appendElements(validation.predictions());

                // Uncomment to see the summary for each training-testing pair:
                // System.out.println(models[j].toString());
            }

            // Calculate the overall accuracy of the current classifier on all splits
            double accuracy = calculateAccuracy(predictions);

            // Print the current classifier's name, its accuracy and its rules
            System.out.println("Accuracy of " + models[j].getClass().getSimpleName() + ": "
                    + String.format("%.2f%%", accuracy)
                    + "\n---------------------------------" + models[j].toString());
        }
    }
}
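On the "just fold 9" question above: since the split arrays are already in scope, one possible approach is to train a fresh J48 on that fold's training split and serialize only that tree. A minimal sketch, assuming it is appended at the end of main above (fold 9 is index 8, since fold indices are 0-based; fold9.model is an illustrative path):

// Train a fresh J48 on fold 9 only and keep that tree
J48 fold9Tree = new J48();
Evaluation fold9Eval = classify(fold9Tree, trainingSplits[8], testingSplits[8]);
System.out.println("Fold 9 accuracy: " + String.format("%.2f%%", fold9Eval.pctCorrect()));

// Serialize the fold-9 model; its rules are also available as plain text
weka.core.SerializationHelper.write("fold9.model", fold9Tree);
System.out.println(fold9Tree.toString());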
Update 2:

My output:
Accuracy of J48: 74.87%
---------------------------------J48 pruned tree
------------------
plas <= 127: tested_negative (437.0/85.0)
plas > 127
| mass <= 29.9: tested_negative (70.0/20.0)
| mass > 29.9
| | pres <= 61: tested_positive (22.0)
| | pres > 61
| | | plas <= 157
| | | | age <= 30
| | | | | preg <= 0
| | | | | | pres <= 68: tested_positive (4.0)
| | | | | | pres > 68
| | | | | | | insu <= 135
| | | | | | | | mass <= 35.5: tested_negative (3.0)
| | | | | | | | mass > 35.5: tested_positive (4.0/1.0)
| | | | | | | insu > 135: tested_negative (2.0)
| | | | | preg > 0
| | | | | | preg <= 2: tested_negative (11.0)
| | | | | | preg > 2
| | | | | | | pedi <= 0.332: tested_negative (6.0)
| | | | | | | pedi > 0.332
| | | | | | | | plas <= 144: tested_positive (4.0)
| | | | | | | | plas > 144: tested_negative (3.0)
| | | | age > 30: tested_positive (52.0/15.0)
| | | plas > 157: tested_positive (74.0/11.0)
Number of Leaves : 13
Size of the tree : 25
Accuracy of PART: 72.40%
---------------------------------J48 pruned tree
------------------
plas <= 127: tested_negative (437.0/85.0)
plas > 127
| mass <= 29.9: tested_negative (70.0/20.0)
| mass > 29.9
| | pres <= 61: tested_positive (22.0)
| | pres > 61
| | | plas <= 157
| | | | age <= 30
| | | | | preg <= 0
| | | | | | pres <= 68: tested_positive (4.0)
| | | | | | pres > 68
| | | | | | | insu <= 135
| | | | | | | | mass <= 35.5: tested_negative (3.0)
| | | | | | | | mass > 35.5: tested_positive (4.0/1.0)
| | | | | | | insu > 135: tested_negative (2.0)
| | | | | preg > 0
| | | | | | preg <= 2: tested_negative (11.0)
| | | | | | preg > 2
| | | | | | | pedi <= 0.332: tested_negative (6.0)
| | | | | | | pedi > 0.332
| | | | | | | | plas <= 144: tested_positive (4.0)
| | | | | | | | plas > 144: tested_negative (3.0)
| | | | age > 30: tested_positive (52.0/15.0)
| | | plas > 157: tested_positive (74.0/11.0)
Number of Leaves : 13
Size of the tree : 25
Accuracy of DecisionTable: 73.96%
---------------------------------J48 pruned tree
------------------
plas <= 127: tested_negative (437.0/85.0)
plas > 127
| mass <= 29.9: tested_negative (70.0/20.0)
| mass > 29.9
| | pres <= 61: tested_positive (22.0)
| | pres > 61
| | | plas <= 157
| | | | age <= 30
| | | | | preg <= 0
| | | | | | pres <= 68: tested_positive (4.0)
| | | | | | pres > 68
| | | | | | | insu <= 135
| | | | | | | | mass <= 35.5: tested_negative (3.0)
| | | | | | | | mass > 35.5: tested_positive (4.0/1.0)
| | | | | | | insu > 135: tested_negative (2.0)
| | | | | preg > 0
| | | | | | preg <= 2: tested_negative (11.0)
| | | | | | preg > 2
| | | | | | | pedi <= 0.332: tested_negative (6.0)
| | | | | | | pedi > 0.332
| | | | | | | | plas <= 144: tested_positive (4.0)
| | | | | | | | plas > 144: tested_negative (3.0)
| | | | age > 30: tested_positive (52.0/15.0)
| | | plas > 157: tested_positive (74.0/11.0)
Number of Leaves : 13
Size of the tree : 25
Accuracy of DecisionStump: 72.01%
---------------------------------J48 pruned tree
------------------
plas <= 127: tested_negative (437.0/85.0)
plas > 127
| mass <= 29.9: tested_negative (70.0/20.0)
| mass > 29.9
| | pres <= 61: tested_positive (22.0)
| | pres > 61
| | | plas <= 157
| | | | age <= 30
| | | | | preg <= 0
| | | | | | pres <= 68: tested_positive (4.0)
| | | | | | pres > 68
| | | | | | | insu <= 135
| | | | | | | | mass <= 35.5: tested_negative (3.0)
| | | | | | | | mass > 35.5: tested_positive (4.0/1.0)
| | | | | | | insu > 135: tested_negative (2.0)
| | | | | preg > 0
| | | | | | preg <= 2: tested_negative (11.0)
| | | | | | preg > 2
| | | | | | | pedi <= 0.332: tested_negative (6.0)
| | | | | | | pedi > 0.332
| | | | | | | | plas <= 144: tested_positive (4.0)
| | | | | | | | plas > 144: tested_negative (3.0)
| | | | age > 30: tested_positive (52.0/15.0)
| | | plas > 157: tested_positive (74.0/11.0)
Number of Leaves : 13
Size of the tree : 25
Answer 1:

Hope you are keeping well too!

I had a look here and could not find any clear indication that the individual classifiers inside the Evaluation class can be accessed. I also ran a test sample with J48 under 10-fold cross-validation, and there likewise appears to be no explicit access to the individual per-fold classifiers.
If you need the best one out of the ten, perhaps you could generate the 10 folds programmatically, apply the training and evaluation routine you used above to each of them, and then save whichever classifier ends up giving the best result (see the sketch below).
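In code, that suggestion could look like the following sketch. It reuses the crossValidationSplit and classify helpers and the trainingSplits/testingSplits arrays from the question's updated program, trains one J48 per fold, and keeps whichever fold's tree scores best on its own test split (bestFold.model is an illustrative path; pctCorrect here is the per-fold accuracy, not the pooled one):

// Train one J48 per fold and remember the most accurate tree
J48 bestTree = null;
double bestPct = -1;
for (int i = 0; i < trainingSplits.length; i++) {
    J48 foldTree = new J48();
    Evaluation foldEval = classify(foldTree, trainingSplits[i], testingSplits[i]);
    if (foldEval.pctCorrect() > bestPct) {
        bestPct = foldEval.pctCorrect();
        bestTree = foldTree;
    }
}
System.out.println("Best fold accuracy: " + String.format("%.2f%%", bestPct));
weka.core.SerializationHelper.write("bestFold.model", bestTree);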
Hope this helps!

Update!

Based on the changes above, below is one possible solution for selecting and saving the best classifier:
// Run for each model, remembering the most accurate one
Classifier bestClassifier = models[0];
double bestAccuracy = -1;

for (int j = 0; j < models.length; j++) {
    // Collect every group of predictions for the current model in a FastVector
    FastVector predictions = new FastVector();

    // For each training-testing split pair, train and test the classifier
    for (int i = 0; i < trainingSplits.length; i++) {
        Evaluation validation = classify(models[j], trainingSplits[i], testingSplits[i]);
        predictions.appendElements(validation.predictions());
    }

    // Calculate the overall accuracy of the current classifier on all splits
    double accuracy = calculateAccuracy(predictions);
    if (accuracy > bestAccuracy) {
        bestClassifier = models[j];
        bestAccuracy = accuracy;
    }

    // Print the current classifier's name and accuracy in a complicated,
    // but nice-looking way.
    System.out.println("Accuracy of " + models[j].getClass().getSimpleName() + ": "
            + String.format("%.2f%%", accuracy));
}

// Save the best classifier
weka.core.SerializationHelper.write("best.model", bestClassifier);
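Once saved, the model can be restored later with the matching read call. A short usage sketch (the cast to Classifier works for any of the models above; toString() prints the learned rules, as in the output shown in the question):

// Restore the serialized best model and print its learned rules
Classifier restored = (Classifier) weka.core.SerializationHelper.read("best.model");
System.out.println(restored.toString());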
Comments:

Hi, thanks for your reply. Essentially what I want to do is find the best of the ten folds and then save that tree's rules, so that I can later build a fuzzy decision tree from them. I don't quite follow what you mean by generating the 10 folds programmatically; would you mind explaining that a little more? Regards

My thinking was that, if you cannot get a classifier out of each fold, perhaps you could generate ten different splits yourself, build a classifier on each one separately, and then save the most accurate classifier from there.

Please see my first post; I have updated it there. Thanks.

I have added a potential solution based on your changes.

Many thanks @matthew. I also added a bestClassifier.toString() to the println so that I can see each classifier's output; I have put my output in my first post above. So now, to save those rules and build a tree from them, could I store them in an array or something? Is there a way to tell where a rule starts and ends? Is parsing possible, or do I have to hardcode it? Regards
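On the parsing question in the last comment: rather than parsing the indented toString() output, note that J48 implements Weka's Drawable interface, so graph() returns the tree in DOT format, where every line is either a node (a split attribute or a leaf) or an edge (the test taken on that branch), which is much easier to split mechanically. A minimal sketch of the idea, assuming the saved best model is the J48 (as in the output above); the line-by-line dump is deliberately crude and only meant to show the shape of the data:

import weka.classifiers.trees.J48;

public class DumpTreeGraph {
    public static void main(String[] args) throws Exception {
        // Load the serialized model saved earlier
        J48 tree = (J48) weka.core.SerializationHelper.read("best.model");

        // graph() returns DOT source, e.g.
        //   N0 [label="plas"]
        //   N0->N1 [label="<= 127"]
        // so each trimmed line is one building block of a rule
        for (String line : tree.graph().split("\n")) {
            System.out.println(line.trim());
        }
    }
}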
您好,感谢您的回复。本质上我想做的是从十个折叠中找到最好的折叠,然后保存该树的规则,以便以后创建一个模糊决策树。我不太明白您所说的以编程方式生成 10 折是什么意思,您介意再解释一下吗?问候 我当时的想法是,如果你不能从每个折叠中得到一个分类器,也许你可以生成十个不同的折叠,然后单独生成分类器,然后保存最准确的分类器从那里开始。 请看我的第一篇文章,我已经在那里更新了,谢谢 我根据您的修改添加了一个潜在的解决方案。 非常感谢@matthew。我还在 println 中添加了一个 bestClassifier.toString(),这样我就可以看到每个分类器的输出,所以我把我的输出放在我的第一篇文章的上面,所以现在我保存这些规则并用它们构建一棵树我可以以某种方式保存它在一个数组或什么的?如果有办法如何知道规则的开始和结束?解析是可能的还是我必须硬编码?问候以上是关于如何使用 Weka API 在 J48 / C4.5 上进行 10 倍交叉验证后保存最佳树的主要内容,如果未能解决你的问题,请参考以下文章