Java、weka LibSVM 预测不正确
Posted
技术标签:
【中文标题】Java、weka LibSVM 预测不正确【英文标题】:Java, weka LibSVM does not predict correctly 【发布时间】:2017-10-12 09:31:52 【问题描述】:我在我的 java 代码中使用带有 weka 的 LibSVM。我正在尝试进行回归。下面是我的代码,
public static void predict()
try
DataSource sourcePref1 = new DataSource("train_pref2new.arff");
Instances trainData = sourcePref1.getDataSet();
DataSource sourcePref2 = new DataSource("testDatanew.arff");
Instances testData = sourcePref2.getDataSet();
if (trainData.classIndex() == -1)
trainData.setClassIndex(trainData.numAttributes() - 2);
if (testData.classIndex() == -1)
testData.setClassIndex(testData.numAttributes() - 2);
LibSVM svm1 = new LibSVM();
String options = ("-S 3 -K 2 -D 3 -G 1000.0 -R 0.0 -N 0.5 -M 40.0 -C 1.0 -E 0.001 -P 0.1");
String[] optionsArray = options.split(" ");
svm1.setOptions(optionsArray);
svm1.buildClassifier(trainData);
for (int i = 0; i < testData.numInstances(); i++)
double pref1 = svm1.classifyInstance(testData.instance(i));
System.out.println("predicted value : " + pref1);
catch (Exception ex)
Logger.getLogger(Test.class.getName()).log(Level.SEVERE, null, ex);
但是我从这段代码得到的预测值与我使用 Weka GUI 得到的预测值不同。
示例: 下面是我为 java 代码和 weka GUI 提供的单个测试数据。
Java 代码预测值为 1.9064516129032265,而 Weka GUI 的预测值为 10.043。我对 Java 代码和 Weka GUI 使用相同的训练数据集和相同的参数。
希望你能理解我的问题。谁能告诉我我的代码有什么问题?
【问题讨论】:
【参考方案1】:您使用了错误的算法来执行 SVM 回归。 LibSVM 用于分类。你想要的是SMOreg,它是一个特定的回归支持向量机。
下面是一个完整的示例,展示了如何通过 Weka Explorer GUI 和 Java API 使用 SMOreg。对于数据,我将使用 Weka 发行版附带的 cpu.arff
数据文件。请注意,我会将此文件用于训练和测试,但理想情况下,您应该拥有单独的数据集。
使用 Weka Explorer GUI
-
打开 WEKA Explorer GUI,单击
Preprocess
选项卡,单击 Open File
,然后打开应该在 Weka 发行版中的 cpu.arff
文件。在我的系统上,该文件位于weka-3-8-1/data/cpu.arff
下。资源管理器窗口应如下所示:
-
单击
Classify
选项卡。它真的应该被称为“预测”,因为你可以在这里进行分类和回归。在Classifier
下,点击Choose
,然后选择weka
--> classifiers
--> functions
--> SMOreg
,如下图。
-
现在构建回归模型并对其进行评估。在
Test Options
下选择Use training set
以便我们的训练集也用于测试(正如我上面提到的,这不是理想的方法)。现在按Start
,结果应该如下所示:
记下 RMSE 值 (74.5996)。我们将在 Java 代码实现中重新讨论这一点。
使用 Java API
下面是一个完整的 Java 程序,它使用 Weka API 来复制前面在 Weka Explorer GUI 中显示的结果。
import weka.classifiers.functions.SMOreg;
import weka.classifiers.Evaluation;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;
public class Tester
/**
* Builds a regression model using SMOreg, the SVM for regression, and
* evaluates it with the Evalution framework.
*/
public void buildAndEvaluate(String trainingArff, String testArff) throws Exception
System.out.printf("buildAndEvaluate() called.\n");
// Load the training and test instances.
Instances trainingInstances = DataSource.read(trainingArff);
Instances testInstances = DataSource.read(testArff);
// Set the true value to be the last field in each instance.
trainingInstances.setClassIndex(trainingInstances.numAttributes()-1);
testInstances.setClassIndex(testInstances.numAttributes()-1);
// Build the SMOregression model.
SMOreg smo = new SMOreg();
smo.buildClassifier(trainingInstances);
// Use Weka's evaluation framework.
Evaluation eval = new Evaluation(trainingInstances);
eval.evaluateModel(smo, testInstances);
// Print the options that were used in the ML algorithm.
String[] options = smo.getOptions();
System.out.printf("Options used:\n");
for (String option : options)
System.out.printf("%s ", option);
System.out.printf("\n\n");
// Print the algorithm details.
System.out.printf("Algorithm:\n %s\n", smo.toString());
// Print the evaluation results.
System.out.printf("%s\n", eval.toSummaryString("\nResults\n=====\n", false));
/**
* Builds a regression model using SMOreg, the SVM for regression, and
* tests each data instance individually to compute RMSE.
*/
public void buildAndTestEachInstance(String trainingArff, String testArff) throws Exception
System.out.printf("buildAndTestEachInstance() called.\n");
// Load the training and test instances.
Instances trainingInstances = DataSource.read(trainingArff);
Instances testInstances = DataSource.read(testArff);
// Set the true value to be the last field in each instance.
trainingInstances.setClassIndex(trainingInstances.numAttributes()-1);
testInstances.setClassIndex(testInstances.numAttributes()-1);
// Build the SMOregression model.
SMOreg smo = new SMOreg();
smo.buildClassifier(trainingInstances);
int numTestInstances = testInstances.numInstances();
// This variable accumulates the squared error from each test instance.
double sumOfSquaredError = 0.0;
// Loop over each test instance.
for (int i = 0; i < numTestInstances; i++)
Instance instance = testInstances.instance(i);
double trueValue = instance.value(testInstances.classIndex());
double predictedValue = smo.classifyInstance(instance);
// Uncomment the next line to see every prediction on the test instances.
//System.out.printf("true=%10.5f, predicted=%10.5f\n", trueValue, predictedValue);
double error = trueValue - predictedValue;
sumOfSquaredError += (error * error);
// Print the RMSE results.
double rmse = Math.sqrt(sumOfSquaredError / numTestInstances);
System.out.printf("RMSE = %10.5f\n", rmse);
public static void main(String argv[]) throws Exception
Tester classify = new Tester();
classify.buildAndEvaluate("../weka-3-8-1/data/cpu.arff", "../weka-3-8-1/data/cpu.arff");
classify.buildAndTestEachInstance("../weka-3-8-1/data/cpu.arff", "../weka-3-8-1/data/cpu.arff");
我编写了两个函数来训练 SMOreg 模型并通过对训练数据运行预测来评估模型。
buildAndEvaluate()
使用 Weka 评估模型
Evaluation
框架运行一套测试以获得完全相同
结果作为资源管理器 GUI。值得注意的是,它会产生一个 RMSE 值。
buildAndTestEachInstance()
通过显式评估模型
循环遍历每个测试实例,进行预测,计算
误差,并计算整体 RMSE。请注意,此 RMSE 匹配
来自buildAndEvaluate()
的那个,反过来匹配那个
从资源管理器 GUI。
下面是编译运行程序的结果。
prompt> javac -cp weka.jar Tester.java
prompt> java -cp .:weka.jar Tester
buildAndEvaluate() called.
Options used:
-C 1.0 -N 0 -I weka.classifiers.functions.supportVector.RegSMOImproved -T 0.001 -V -P 1.0E-12 -L 0.001 -W 1 -K weka.classifiers.functions.supportVector.PolyKernel -E 1.0 -C 250007
Algorithm:
SMOreg
weights (not support vectors):
+ 0.01 * (normalized) MYCT
+ 0.4321 * (normalized) MMIN
+ 0.1847 * (normalized) MMAX
+ 0.1175 * (normalized) CACH
+ 0.0973 * (normalized) CHMIN
+ 0.0235 * (normalized) CHMAX
- 0.0168
Number of kernel evaluations: 21945 (93.081% cached)
Results
=====
Correlation coefficient 0.9044
Mean absolute error 31.7392
Root mean squared error 74.5996
Relative absolute error 33.0908 %
Root relative squared error 46.4953 %
Total Number of Instances 209
buildAndTestEachInstance() called.
RMSE = 74.59964
【讨论】:
实际上 Libsvm 有 2 种回归的 svm 类型,nu-SVR 和 epsilon-SVR。我可以通过定义算法的 -S 参数来决定使用哪种 svm 类型。在我的代码中,我使用了 epsilon-SVR(-S 3)。但是您的代码确实帮助我找到了代码中的错误。 setClassIndex 在我的代码中是错误的。我使用了你的代码并且它有效。非常感谢您的帮助。以上是关于Java、weka LibSVM 预测不正确的主要内容,如果未能解决你的问题,请参考以下文章