Error when using Apache Spark decision tree classifier for multiclass classification

Posted: 2017-04-09 10:09:46

Question:

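I am trying to classify user activities from sensor data collected on mobile devices. The dataset contains the user ID, the sensor readings, and the activity. Activities are encoded as integers, and there are 12 activity classes. Below is the code I use for this activity-recognition problem; I am using the Apache Spark (MLlib) decision tree for multiclass classification.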

import java.util.HashMap;
import java.util.Map;

import scala.Tuple2;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.DecisionTree;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
import org.apache.spark.mllib.util.MLUtils;

public class DecisionTreeClass {
    public static void main(String[] args) {
        SparkConf sparkConf = new SparkConf().setAppName("DecisionTreeClass").setMaster("local[2]");
        JavaSparkContext jsc = new JavaSparkContext(sparkConf);

        // Load and parse the data file.
        String datapath = "/home/thamali/Desktop/Project/csv/libsvm/trainlib.txt";
        JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD();
        // Split the data into training and test sets (30% held out for testing)
        JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[]{0.7, 0.3});
        JavaRDD<LabeledPoint> trainingData = splits[0];
        JavaRDD<LabeledPoint> testData = splits[1];

        // Set parameters.
        // Empty categoricalFeaturesInfo indicates all features are continuous.
        Integer numClasses = 12;
        Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<>();
        String impurity = "gini";
        Integer maxDepth = 5;
        Integer maxBins = 32;

        // Train a DecisionTree model for classification.
        final DecisionTreeModel model = DecisionTree.trainClassifier(trainingData, numClasses,
                categoricalFeaturesInfo, impurity, maxDepth, maxBins);

        // Evaluate model on test instances and compute test error
        JavaPairRDD<Double, Double> predictionAndLabel =
                testData.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
                    @Override
                    public Tuple2<Double, Double> call(LabeledPoint p) {
                        return new Tuple2<>(model.predict(p.features()), p.label());
                    }
                });
        Double testErr =
                1.0 * predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() {
                    @Override
                    public Boolean call(Tuple2<Double, Double> pl) {
                        return !pl._1().equals(pl._2());
                    }
                }).count() / testData.count();

        System.out.println("Test Error: " + testErr);
        System.out.println("Learned classification tree model:\n" + model.toDebugString());

        // Save and load model
        model.save(jsc.sc(), "target/tmp/myDecisionTreeClassificationModel");
        DecisionTreeModel sameModel = DecisionTreeModel
                .load(jsc.sc(), "target/tmp/myDecisionTreeClassificationModel");

        jsc.stop();
    }
}


When I run the code above, I get the following exception. Can someone help me fix this?

Caused by: java.lang.IllegalArgumentException: GiniAggregator given label 17.0 but requires label < numClasses (= 12).
    at org.apache.spark.mllib.tree.impurity.GiniAggregator.update(Gini.scala:92)
    at org.apache.spark.ml.tree.impl.DTStatsAggregator.update(DTStatsAggregator.scala:109)
    at org.apache.spark.ml.tree.impl.RandomForest$.orderedBinSeqOp(RandomForest.scala:326)
    at org.apache.spark.ml.tree.impl.RandomForest$.org$apache$spark$ml$tree$impl$RandomForest$$nodeBinSeqOp$1(RandomForest.scala:416)
    at org.apache.spark.ml.tree.impl.RandomForest$$anonfun$org$apache$spark$ml$tree$impl$RandomForest$$binSeqOp$1$1.apply(RandomForest.scala:441)
    at org.apache.spark.ml.tree.impl.RandomForest$$anonfun$org$apache$spark$ml$tree$impl$RandomForest$$binSeqOp$1$1.apply(RandomForest.scala:439)
    at scala.collection.immutable.Map$Map1.foreach(Map.scala:109)
    at org.apache.spark.ml.tree.impl.RandomForest$.org$apache$spark$ml$tree$impl$RandomForest$$binSeqOp$1(RandomForest.scala:439)
    at org.apache.spark.ml.tree.impl.RandomForest$$anonfun$9$$anonfun$apply$9.apply(RandomForest.scala:532)
    at org.apache.spark.ml.tree.impl.RandomForest$$anonfun$9$$anonfun$apply$9.apply(RandomForest.scala:532)
    at scala.collection.Iterator$class.foreach(Iterator.scala:727)
    at org.apache.spark.InterruptibleIterator.foreach(InterruptibleIterator.scala:28)
    at org.apache.spark.ml.tree.impl.RandomForest$$anonfun$9.apply(RandomForest.scala:532)
    at org.apache.spark.ml.tree.impl.RandomForest$$anonfun$9.apply(RandomForest.scala:521)
    at org.apache.spark.rdd.RDD$$anonfun$mapPartitions$1$$anonfun$apply$23.apply(RDD.scala:785)
    at org.apache.spark.rdd.RDD$$anonfun$mapPartitions$1$$anonfun$apply$23.apply(RDD.scala:785)
    at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:319)
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:283)
    at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:79)
    at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:47)
    at org.apache.spark.scheduler.Task.run(Task.scala:86)
    at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:274)
    at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)
    at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)
    at java.lang.Thread.run(Thread.java:745)
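The message says that training hit a label of 17.0 while numClasses was 12; the MLlib decision-tree guide states that labels should take the values 0, 1, ..., numClasses - 1. A quick way to see which labels a LIBSVM file actually contains is to read the first token of every line. The class below is a minimal plain-Java sketch (LibSvmLabels and its methods are hypothetical names, not part of Spark). It assumes the usual LIBSVM layout "label index:value ..." and skips blank lines and lines starting with '#'. With Spark on the classpath, data.map(LabeledPoint::label).distinct().collect() on the question's RDD should yield the same set of labels.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.List;
import java.util.SortedSet;
import java.util.TreeSet;

/** Plain-Java label diagnostics for a LIBSVM file (no Spark needed). */
final class LibSvmLabels {

    /** Distinct labels, i.e. the first token of every non-blank, non-comment line. */
    static SortedSet<Double> distinctLabels(List<String> lines) {
        SortedSet<Double> labels = new TreeSet<>();
        for (String line : lines) {
            String trimmed = line.trim();
            if (trimmed.isEmpty() || trimmed.startsWith("#")) {
                continue; // skip blank lines and comment lines
            }
            labels.add(Double.parseDouble(trimmed.split("\\s+", 2)[0]));
        }
        return labels;
    }

    /** Labels that are not whole numbers in [0, numClasses). */
    static SortedSet<Double> invalidLabels(SortedSet<Double> labels, int numClasses) {
        SortedSet<Double> invalid = new TreeSet<>();
        for (double label : labels) {
            boolean whole = label == Math.rint(label);
            if (!whole || label < 0 || label >= numClasses) {
                invalid.add(label);
            }
        }
        return invalid;
    }

    /** Smallest numClasses such that every (non-negative) label is < numClasses. */
    static int minNumClasses(SortedSet<Double> labels) {
        return labels.isEmpty() ? 0 : (int) Math.floor(labels.last()) + 1;
    }

    public static void main(String[] args) throws IOException {
        // args[0]: path to the LIBSVM file, e.g. the trainlib.txt used in the question.
        SortedSet<Double> labels = distinctLabels(Files.readAllLines(Paths.get(args[0])));
        System.out.println("Distinct labels: " + labels);
        System.out.println("Invalid for numClasses = 12: " + invalidLabels(labels, 12));
        System.out.println("Minimum numClasses: " + minNumClasses(labels));
    }
}
```

The whole-number test is stricter than the check visible in the stack trace (which only compares label against numClasses); it reflects the documented expectation that labels are 0, 1, ..., numClasses - 1. For a file whose largest label is 17, minNumClasses returns 18.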

Comments:

Check your data: the message says "given label 17.0 but requires label < numClasses (= 12)", so somewhere in your dataset there is a label of 17.

Thanks a lot!

Answer 1:

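Change it so that numClasses is larger than every label. Because labels must satisfy label < numClasses, a label of 17 needs numClasses of at least 18 (assuming 17 is the largest label in the data):

Integer numClasses = 18;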
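Raising numClasses until it exceeds the largest label works, but the Gini statistics are then sized for class indices that never occur. If the 12 activities are simply coded with non-consecutive IDs, an alternative is to remap them to 0..11 before training and map predictions back afterwards. A minimal sketch follows; LabelIndexer is a hypothetical helper, not a Spark class. It implements Serializable on the assumption that it will be captured in a Spark closure, for example data.map(p -> new LabeledPoint(indexer.encode(p.label()), p.features())), with the indexer built from data.map(LabeledPoint::label).distinct().collect().

```java
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.TreeSet;

/** Hypothetical helper: maps arbitrary class labels to 0..k-1 and back. */
final class LabelIndexer implements Serializable {
    private static final long serialVersionUID = 1L;

    private final Map<Double, Double> toIndex = new HashMap<>();
    private final List<Double> toOriginal = new ArrayList<>();

    LabelIndexer(Collection<Double> labels) {
        // Sort and de-duplicate so the same label set always yields the same mapping.
        for (double label : new TreeSet<>(labels)) {
            toIndex.put(label, (double) toOriginal.size());
            toOriginal.add(label);
        }
    }

    /** Value to pass as numClasses to DecisionTree.trainClassifier. */
    int numClasses() {
        return toOriginal.size();
    }

    /** Original label -> class index in [0, numClasses). */
    double encode(double label) {
        Double index = toIndex.get(label);
        if (index == null) {
            throw new IllegalArgumentException("Unknown label: " + label);
        }
        return index;
    }

    /** Predicted class index (as returned by model.predict) -> original label. */
    double decode(double prediction) {
        return toOriginal.get((int) prediction);
    }
}
```

If the file really contains 12 distinct activities, numClasses() returns 12 and the original setting can stay. In the DataFrame-based spark.ml API, StringIndexer plays a similar role.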
