Differences among decision tree types, with Spark 2.0 MLlib / Scikit code analysis
Posted by 千寻千梦
Overview
- A classification decision tree model is a tree structure that describes how instances are classified. A decision tree can be viewed as a set of if-then rules that is mutually exclusive and collectively exhaustive. Decision trees are essentially built with a greedy (non-backtracking) algorithm, constructed top-down by recursive divide and conquer.
- Building a decision tree generally involves three steps:
- Feature selection
- Tree generation
- Pruning
Types of decision tree algorithms
- The main decision tree algorithms are ID3, C4.5, C5.0 and CART. ID3, C4.5 and CART all use a greedy (non-backtracking) algorithm, constructing the tree top-down by recursive divide and conquer. At each decision, the goal is to maximize the "difference" between the resulting groups; the main distinction among the algorithms is how this "difference" is measured.
- ID3 and CART were invented independently at around the same time and form the foundation of research on decision tree induction.
For ID3, see: http://blog.csdn.net/acdreamers/article/details/44661149
In information theory, the smaller the expected information, the larger the information gain and therefore the higher the purity. The core idea of ID3 is to use information gain as the criterion for attribute selection: at each step, split on the attribute whose split yields the largest information gain. The algorithm performs a greedy top-down search through the space of possible decision trees. Under the information gain criterion, a feature's importance is measured by how much information it contributes to the classification system: the more information it brings, the more important the feature. Before looking at information gain, consider the definition of information entropy, written out below (where P(C_i) is the probability of class C_i).
The concept of entropy originated in physics, where it measures the disorder of a thermodynamic system; in information theory, entropy is a measure of uncertainty. In 1948, Shannon introduced information entropy, defined in terms of the probabilities of discrete random events: the more ordered a system is, the lower its information entropy; the more chaotic, the higher. Information entropy can therefore be regarded as a measure of how ordered a system is.
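For reference, the standard textbook definitions of the two quantities discussed above are as follows (the notation is generic and not tied to any particular implementation):

H(D) = -\sum_{i=1}^{k} P(C_i) \log_2 P(C_i)

Gain(D, A) = H(D) - \sum_{v \in values(A)} \frac{|D_v|}{|D|} H(D_v)

Here D is the training set, C_1..C_k are the classes, A is a candidate attribute, and D_v is the subset of D on which A takes the value v. ID3 splits on the attribute A that maximizes Gain(D, A).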
C4.5 improves on the ID3 decision tree and overcomes two of its shortcomings:
(1) Selecting attributes by information gain is biased toward attributes with many branches, i.e. attributes with many distinct values.
(2) It cannot handle continuous attributes.
C4.5 addresses these as follows:
(1) It splits the data on the feature with the largest information gain ratio and, like ID3, applies post-pruning. Using a ratio largely resolves ID3's first shortcoming.
(2) When a feature is continuous, it is discretized during classification. C5.0 is Quinlan's latest release of the decision tree algorithm, available under a proprietary license. Compared with C4.5 it uses less memory, builds smaller rule sets, and is more accurate; thanks to its improved execution efficiency and memory usage, C5.0 is suited to large data sets. C5.0 chooses the best split variable and split threshold by how quickly the split reduces information entropy; a decrease in entropy means a decrease in uncertainty.
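The information gain ratio used by C4.5 in step (1) is usually defined as follows (standard textbook notation, continuing the symbols above; this is background, not code from any library):

SplitInfo(D, A) = -\sum_{v \in values(A)} \frac{|D_v|}{|D|} \log_2 \frac{|D_v|}{|D|}

GainRatio(D, A) = \frac{Gain(D, A)}{SplitInfo(D, A)}

Because SplitInfo(D, A) grows with the number of distinct values of A, dividing by it penalizes attributes with many values, which is exactly ID3's first weakness.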
(1) ID3 (Iterative Dichotomiser 3) was developed in 1986 by Ross Quinlan. The algorithm creates a multiway tree, finding for each node (i.e. in a greedy manner) the categorical feature that will yield the largest information gain for categorical targets. Trees are grown to their maximum size and then a pruning step is usually applied to improve the ability of the tree to generalise to unseen data.
(2) C4.5 is the successor to ID3 and removed the restriction that features must be categorical by dynamically defining a discrete attribute (based on numerical variables) that partitions the continuous attribute value into a discrete set of intervals. C4.5 converts the trained trees (i.e. the output of the ID3 algorithm) into sets of if-then rules. The accuracy of each rule is then evaluated to determine the order in which they should be applied. Pruning is done by removing a rule's precondition if the accuracy of the rule improves without it.
(3) C5.0 is Quinlan's latest version release under a proprietary license. It uses less memory and builds smaller rulesets than C4.5 while being more accurate.
(4) CART (Classification and Regression Trees) is very similar to C4.5, but it differs in that it supports numerical target variables (regression) and does not compute rule sets. CART constructs binary trees using the feature and threshold that yield the largest information gain at each node.
CART (Classification and Regression Trees) is very similar to C4.5. It uses Gini impurity to decide splits. For the differences, see: http://blog.csdn.net/lingtianyulong/article/details/34522757
It is essentially a similar algorithm to C4.5; the main differences are: 1) for regression, a leaf node holds not a class label but a regression function f() defined for that region; 2) CART builds binary trees rather than multiway trees.
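The Gini impurity that CART uses is commonly defined as (standard notation, with p_i the fraction of class i among the samples at a node):

Gini(D) = 1 - \sum_{i=1}^{k} p_i^2

A split is chosen so as to maximize the reduction in the weighted impurity of the child nodes, playing the same role that information gain plays for entropy.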
Decision trees in Scikit
Reference: http://blog.csdn.net/sandyzhs/article/details/46814805
- scikit-learn uses an optimised version of the CART algorithm.
Detailed analysis of Spark MLlib decision tree code
- The decision tree in Spark MLlib discussed here is the Spark 2.0, DataFrame-based API.
- Decision trees, and ensembles of decision trees, are a popular family of algorithms for classification and regression problems, with several advantages:
- The results are easy to interpret;
- Categorical features are handled directly;
- They extend to multiclass classification;
- Features do not need to be normalized;
- Interactions between features can be captured.
Random forests and boosting algorithms are both ensembles of decision trees.
Spark's decision trees handle binary classification, multiclass classification, and regression, and accept both continuous and categorical features. Because the data set is partitioned by rows, training can be distributed over large data sets (millions or even billions of instances).
(1)Decision trees and their ensembles are popular methods for the machine learning tasks of classification and regression. Decision trees are widely used since they are easy to interpret, handle categorical features, extend to the multiclass classification setting, do not require feature scaling, and are able to capture non-linearities and feature interactions. Tree ensemble algorithms such as random forests and boosting are among the top performers for classification and regression tasks.
(2)The spark.ml implementation supports decision trees for binary and multiclass classification and for regression, using both continuous and categorical features. The implementation partitions data by rows, allowing distributed training with millions or even billions of instances.
When building a decision tree workflow in a Spark Pipeline, the following three feature transformers are commonly used but not especially easy to grasp (refer to their respective examples later in this section).
http://spark.apache.org/docs/latest/ml-features.html#vectorindexer
VectorIndexer
Main purpose: improve the classification performance of ML methods such as decision trees and random forests.
VectorIndexer indexes the categorical (discrete-valued) features inside a dataset's feature vectors.
It automatically decides which features are categorical and indexes them. Concretely, you set maxCategories; any feature with at most maxCategories distinct values is declared categorical and its values are re-indexed to 0..K-1 (K <= maxCategories), while a feature with more distinct values than maxCategories is treated as continuous and left unchanged. The example below makes this much less confusing.
VectorIndexer helps index categorical features in datasets of Vectors. It can both automatically decide which features are categorical and convert original values to category indices. Specifically, it does the following:
Take an input column of type Vector and a parameter maxCategories. Decide which features should be categorical based on the number of distinct values, where features with at most maxCategories are declared categorical.
Compute 0-based category indices for each categorical feature.
Index categorical features and transform original feature values to indices.
Indexing categorical features allows algorithms such as Decision Trees and Tree Ensembles to treat categorical features appropriately, improving performance.
This transformed data could then be passed to algorithms such as DecisionTreeRegressor that handle categorical features.
Here is an example on a simple dataset:
// Define the input/output columns and set the maximum number of categories to 5:
// a feature (i.e. a column) with more than 5 distinct values is treated as continuous
VectorIndexerModel featureIndexerModel=new VectorIndexer()
.setInputCol("features")
.setMaxCategories(5)
.setOutputCol("indexedFeatures")
.fit(rawData);
// Add to the Pipeline
Pipeline pipeline=new Pipeline()
.setStages(new PipelineStage[]
{labelIndexerModel,
featureIndexerModel,
dtClassifier,
converter});
pipeline.fit(rawData).transform(rawData).select("features","indexedFeatures").show(20,false);
// Produces the following output:
+-------------------------+-------------------------+
|features |indexedFeatures |
+-------------------------+-------------------------+
|(3,[0,1,2],[2.0,5.0,7.0])|(3,[0,1,2],[2.0,1.0,1.0])|
|(3,[0,1,2],[3.0,5.0,9.0])|(3,[0,1,2],[3.0,1.0,2.0])|
|(3,[0,1,2],[4.0,7.0,9.0])|(3,[0,1,2],[4.0,3.0,2.0])|
|(3,[0,1,2],[2.0,4.0,9.0])|(3,[0,1,2],[2.0,0.0,2.0])|
|(3,[0,1,2],[9.0,5.0,7.0])|(3,[0,1,2],[9.0,1.0,1.0])|
|(3,[0,1,2],[2.0,5.0,9.0])|(3,[0,1,2],[2.0,1.0,2.0])|
|(3,[0,1,2],[3.0,4.0,9.0])|(3,[0,1,2],[3.0,0.0,2.0])|
|(3,[0,1,2],[8.0,4.0,9.0])|(3,[0,1,2],[8.0,0.0,2.0])|
|(3,[0,1,2],[3.0,6.0,2.0])|(3,[0,1,2],[3.0,2.0,0.0])|
|(3,[0,1,2],[5.0,9.0,2.0])|(3,[0,1,2],[5.0,4.0,0.0])|
+-------------------------+-------------------------+
Analysis of the result: each feature vector contains 3 features, feature 0, feature 1, and feature 2. For example, in row 1 the values 2.0, 5.0, 7.0 are transformed into 2.0, 1.0, 1.0.
Only feature 1 and feature 2 were transformed; feature 0 was not. This is because feature 0 has 6 distinct values (2, 3, 4, 5, 8, 9), more than the setMaxCategories(5) setting,
so it is treated as continuous and left unchanged.
Feature 1: (4, 5, 6, 7, 9) --> (0, 1, 2, 3, 4)
Feature 2: (2, 7, 9) --> (0, 1, 2)
Format of the output DataFrame (row 1):
|(3, [0,1,2], [2.0,5.0,7.0])| : 3 features, feature indices 0, 1, 2, values before indexing
|(3, [0,1,2], [2.0,1.0,1.0])| : 3 features, feature indices 0, 1, 2, values after indexing
StringIndexer
Once VectorIndexer is understood, StringIndexer, which re-indexes a dataset's labels, is easy to follow: it uses the same kind of transformation. The example below is enough.
// Define a StringIndexerModel that converts label into indexedLabel
StringIndexerModel labelIndexerModel=new StringIndexer().
setInputCol("label")
.setOutputCol("indexedLabel")
.fit(rawData);
// Add labelIndexerModel to the Pipeline
Pipeline pipeline=new Pipeline()
.setStages(new PipelineStage[]
{labelIndexerModel,
featureIndexerModel,
dtClassifier,
converter});
// Inspect the result
pipeline.fit(rawData).transform(rawData).select("label","indexedLabel").show(20,false);
Labels are indexed by frequency into 0 to numOfLabels-1 (the number of classes): the most frequent label is indexed as 0, and so on:
label=3 occurs most often (4 times), so it is indexed as 0;
next is label=2 with 3 occurrences, indexed as 1, and so on.
+-----+------------+
|label|indexedLabel|
+-----+------------+
|3.0 |0.0 |
|4.0 |3.0 |
|1.0 |2.0 |
|3.0 |0.0 |
|2.0 |1.0 |
|3.0 |0.0 |
|2.0 |1.0 |
|3.0 |0.0 |
|2.0 |1.0 |
|1.0 |2.0 |
+-----+------------+
Two further points to note when applying StringIndexer elsewhere: (1) StringIndexer essentially maps String -> index (number); if the input is numeric, the values are first cast to strings and then indexed ("cast numeric to string and then index the string values"), so both strings and numbers can be re-indexed. (2) When the fitted model is used to transform a new dataset, unexpected situations can arise, as the example below shows.
StringIndexer indexes strings by frequency:
id | category | categoryIndex
----|----------|---------------
0 | a | 0.0
1 | b | 2.0
2 | c | 1.0
3 | a | 0.0
4 | a | 0.0
5 | c | 1.0
If the mapping (a, b, c) -> (0.0, 2.0, 1.0) was fitted on the data above, and the model is then used to transform data whose categories go beyond (a, b, c), say with additional values d and e, there is trouble:
id | category | categoryIndex
----|----------|---------------
0 | a | 0.0
1 | b | 2.0
2 | d | ?
3 | e | ?
4 | a | 0.0
5 | c | 1.0
Spark provides two ways of handling this:
StringIndexerModel labelIndexerModel=new StringIndexer().
setInputCol("label")
.setOutputCol("indexedLabel")
//.setHandleInvalid("error")
.setHandleInvalid("skip")
.fit(rawData);
(1) The default setting, i.e. .setHandleInvalid("error"), throws an exception:
org.apache.spark.SparkException: Unseen label: d,e
(2) .setHandleInvalid("skip") ignores the rows containing those labels, runs normally, and produces the following output:
id | category | categoryIndex
----|----------|---------------
0 | a | 0.0
1 | b | 2.0
4 | a | 0.0
5 | c | 1.0
IndexToString
Correspondingly, where there is a StringIndexer there should also be an IndexToString. After StringIndexer re-indexes the labels, the model is trained on the indexed labels and its predictions on new data are indexed as well, so they have to be converted back. In the example below, it is convetedPrediction that corresponds to the original labels. Symmetrically to StringIndexer, IndexToString maps a column of label indices back to a column containing the original labels as strings. A common use case is to produce indices from labels with StringIndexer, train a model with those indices and retrieve the original labels from the column of predicted indices with IndexToString.
IndexToString converter=new IndexToString()
.setInputCol("prediction")//Spark默认预测label行
.setOutputCol("convetedPrediction")//转换回来的预测label
.setLabels(labelIndexerModel.labels());//需要指定前面建好相互相互模型
Pipeline pipeline=new Pipeline()
.setStages(new PipelineStage[]
{labelIndexerModel,
featureIndexerModel,
dtClassifier,
converter});
pipeline.fit(rawData).transform(rawData)
.select("label","prediction","convetedPrediction").show(20,false);
+-----+----------+------------------+
|label|prediction|convetedPrediction|
+-----+----------+------------------+
|3.0 |0.0 |3.0 |
|4.0 |1.0 |2.0 |
|1.0 |2.0 |1.0 |
|3.0 |0.0 |3.0 |
|2.0 |1.0 |2.0 |
|3.0 |0.0 |3.0 |
|2.0 |1.0 |2.0 |
|3.0 |0.0 |3.0 |
|2.0 |1.0 |2.0 |
|1.0 |2.0 |1.0 |
+-----+----------+------------------+
Pruning methods and decision tree parameter settings in Spark MLlib
**Pruning parameters:** Pre-pruning "prunes" the tree by stopping its construction early, through the following conditions:
- maxDepth: limits the maximum possible depth of the decision tree. Because of other stopping conditions or pruning, the final tree may end up shallower than maxDepth.
- minInfoGain: the minimum information gain (a threshold); a node is not split further if no split gains more than this value.
- minInstancesPerNode: a node is not split unless each of its children would receive at least this many training instances (a threshold).
In practice it is quite hard to choose an appropriate threshold: a high threshold can oversimplify the tree, while a low threshold may not simplify it enough. The pre-pruning parameters minInfoGain and minInstancesPerNode essentially work by repeatedly adjusting the stopping conditions until a reasonable result appears, which is not a great approach; often we do not even know what result we are looking for. That is where post-pruning comes in (post-pruning requires no user-specified parameters and is the more principled method).
- Does Spark MLlib use post-pruning? I have not yet worked that out.
- Of course, post-pruning is not always more effective than pre-pruning. To find the best model, the more sensible approach is to use both techniques together.
DecisionTreeClassifier dtClassifier=new DecisionTreeClassifier()
.setLabelCol("indexedLabel")
.setFeaturesCol("indexedFeatures")
.setMaxDepth(maxDepth)
//.setMinInfoGain(0.5)
//.setMinInstancesPerNode(10)
//.setImpurity("gini")// Gini impurity
.setImpurity("entropy");// or entropy
(From the Spark docs) The recursive tree construction is stopped at a node when one of the following conditions is met:
1. The node depth is equal to the maxDepth training parameter.
2. No split candidate leads to an information gain greater than minInfoGain.
3. No split candidate produces child nodes which each have at least minInstancesPerNode training instances.
Node impurity and information gain settings:
For classification problems you can set either:
.setImpurity("gini")// Gini impurity
.setImpurity("entropy")// or entropy
Evaluating the classification results
(1) Manual exploration: simply set up a loop over the key parameter maxDepth and the two impurity measures, and compute the accuracy of each combination.
(2) Use the CrossValidator cross-validation method (a minimal sketch follows after this list); see also my other article:
Model selection and hyper-parameter tuning in Spark 2.0 with Pipeline, cross validation and ParamMap
http://blog.csdn.net/qq_34531825/article/details/52334436
(3) For evaluation metrics other than accuracy on binary classification problems, see my other article:
Logistic regression parameter settings and classification evaluation (Spark 2.0, Python Scikit)
http://blog.csdn.net/qq_34531825/article/details/52313553
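As a minimal sketch of option (2), the snippet below wires the same Pipeline into a CrossValidator. It reuses the pipeline, dtClassifier, training and test objects from the complete example that follows; the grid values and the number of folds are illustrative assumptions, not recommendations.
// Requires, in addition to the imports of the complete example below:
// import org.apache.spark.ml.param.ParamMap;
// import org.apache.spark.ml.tuning.CrossValidator;
// import org.apache.spark.ml.tuning.CrossValidatorModel;
// import org.apache.spark.ml.tuning.ParamGridBuilder;
// Evaluator reused both for cross validation and for the final test-set check
MulticlassClassificationEvaluator evaluator=new MulticlassClassificationEvaluator()
.setLabelCol("indexedLabel")
.setPredictionCol("prediction")
.setMetricName("accuracy");
// Candidate parameter combinations for the decision tree (illustrative values)
ParamMap[] paramGrid=new ParamGridBuilder()
.addGrid(dtClassifier.maxDepth(), new int[]{3,5,7})
.addGrid(dtClassifier.minInstancesPerNode(), new int[]{1,10})
.build();
CrossValidator cv=new CrossValidator()
.setEstimator(pipeline)// the Pipeline built as in the example below
.setEvaluator(evaluator)
.setEstimatorParamMaps(paramGrid)
.setNumFolds(3);// 3-fold cross validation
// Fit on the training set; the best parameter combination is kept internally
CrossValidatorModel cvModel=cv.fit(training);
// Evaluate the selected model on the held-out test set
System.out.println("cross-validated accuracy: "+evaluator.evaluate(cvModel.transform(test)));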
Complete code for a Spark 2.0 decision tree classification problem
package my.spark.ml.practice;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.classification.DecisionTreeClassificationModel;
import org.apache.spark.ml.classification.DecisionTreeClassifier;
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
import org.apache.spark.ml.feature.IndexToString;
import org.apache.spark.ml.feature.StringIndexer;
import org.apache.spark.ml.feature.StringIndexerModel;
import org.apache.spark.ml.feature.VectorIndexer;
import org.apache.spark.ml.feature.VectorIndexerModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
public class myDecisionTreeClassifer {
public static void main(String[] args) {
SparkSession spark=SparkSession
.builder()
.master("local[4]")
.appName("myDecisonTreeClassifer")
.getOrCreate();
//Suppress logging
Logger.getLogger("org.apache.spark").setLevel(Level.WARN);
Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF);
//-------------------0 Load the data------------------------------------------
String path="/home/hadoop/spark/spark-2.0.0-bin-hadoop2.6" +
"/data/mllib/sample_multiclass_classification_data.txt";
//"/data/mllib/sample_libsvm_data.txt";
Dataset<Row> rawData=spark.read().format("libsvm").load(path);
Dataset<Row>[] split=rawData.randomSplit(new double[]{0.8,0.2});
Dataset<Row> training=split[0];
Dataset<Row> test=split[1];
//rawData.show(100,false);// inspect the loaded data: show 100 rows, each row untruncated
//-------------------1 Build the Pipeline for decision tree training-------------------------------
//1.1 Re-index the labels
StringIndexerModel labelIndexerModel=new StringIndexer().
setInputCol("label")
.setOutputCol("indexedLabel")
//.setHandleInvalid("error")
.setHandleInvalid("skip")
.fit(rawData);
//1.2 Index the feature vectors
// Automatically identify categorical features, and index them.
// Set maxCategories so features with > 5 distinct values are
//treated as continuous.
//This targets categorical features: their discrete values are indexed.
//.setMaxCategories(5) means a feature with more than 5 distinct values is treated
//as continuous, i.e. indexing simply does not apply to it.
VectorIndexerModel featureIndexerModel=new VectorIndexer()
.setInputCol("features")
.setMaxCategories(5)
.setOutputCol("indexedFeatures")
.fit(rawData);
//1.3 Decision tree classifier
/*DecisionTreeClassifier dtClassifier=
new DecisionTreeClassifier()
.setLabelCol("indexedLabel")// use the indexed label
.setFeaturesCol("indexedFeatures");// use the indexed features
*/
//1.3 Decision tree classifier parameter settings
for(int maxDepth=2;maxDepth<10;maxDepth++){
DecisionTreeClassifier dtClassifier=new DecisionTreeClassifier()
.setLabelCol("indexedLabel")
.setFeaturesCol("indexedFeatures")
.setMaxDepth(maxDepth)
//.setMinInfoGain(0.5)
//.setMinInstancesPerNode(10)
//.setImpurity("gini")//Gini不纯度
.setImpurity("entropy")//或者熵
//.setMaxBins(100)//其它可调试的还有一些参数
;
//1.4 Convert the indexed prediction labels back
IndexToString converter=new IndexToString()
.setInputCol("prediction")//自动产生的预测label行名字
.setOutputCol("convetedPrediction")
.setLabels(labelIndexerModel.labels());
//The Pipeline consists of these four stages
Pipeline pipeline=new Pipeline()
.setStages(new PipelineStage[]
{labelIndexerModel,
featureIndexerModel,
dtClassifier,
converter});
//Fit the pipeline model on the training set
PipelineModel pipelineModel=pipeline.fit(training);
//-----------------------------3 Multiclass evaluation----------------------------
//Predict
Dataset<Row> testPrediction=pipelineModel.transform(test);
MulticlassClassificationEvaluator evaluator=
new MulticlassClassificationEvaluator()
.setLabelCol("indexedLabel")
.setPredictionCol("prediction")
.setMetricName("accuracy");
//Evaluate
System.out.println("MaxDepth is: "+maxDepth);
double accuracy= evaluator.evaluate(testPrediction);
System.out.println("accuracy is: "+accuracy);
//Print the decision tree model
DecisionTreeClassificationModel treeModel =
(DecisionTreeClassificationModel) (pipelineModel.stages()[2]);
System.out.println("Learned classification tree model depth"
+treeModel.depth()+" numNodes "+treeModel.numNodes());
//+ treeModel.toDebugString()); // print the full decision tree rule set
}// end of the maxDepth loop
}
}
Performance of the classification model on the test set:
entropy:
MaxDepth is: 2
accuracy is: 0.896551724137931
Learned classification tree model depth2 numNodes 5
MaxDepth is: 3
accuracy is: 0.9310344827586207
Learned classification tree model depth3 numNodes 7
MaxDepth is: 4
accuracy is: 0.9310344827586207
Learned classification tree model depth4 numNodes 9
MaxDepth is: 5
accuracy is: 0.9310344827586207
Learned classification tree model depth5 numNodes 11
MaxDepth is: 6
accuracy is: 0.9310344827586207
Learned classification tree model depth5 numNodes 11
Gini:
MaxDepth is: 2
accuracy is: 0.8928571428571429
Learned classification tree model depth2 numNodes 5
MaxDepth is: 3
accuracy is: 0.9642857142857143
Learned classification tree model depth3 numNodes 9
MaxDepth is: 4
accuracy is: 0.9285714285714286
Learned classification tree model depth4 numNodes 13
MaxDepth is: 5
accuracy is: 0.9285714285714286
Learned classification tree model depth4 numNodes 13
MaxDepth is: 6
accuracy is: 0.9285714285714286
Learned classification tree model depth4 numNodes 13
In addition, treeModel.toDebugString() produces a rule set like the following:
MaxDepth is: 3
accuracy is: 0.9666666666666667
Learned classification tree model depthDecisionTreeClassificationModel
(uid=dtc_62e3aea12022) of depth 3 with 9 nodes
If (feature 2 <= -0.694915)
Predict: 0.0
Else (feature 2 > -0.694915)
If (feature 3 <= 0.25)
If (feature 2 <= 0.322034)
Predict: 2.0
Else (feature 2 > 0.322034)
Predict: 1.0
Else (feature 3 > 0.25)
If (feature 2 <= 0.288136)
Predict: 1.0
Else (feature 2 > 0.288136)
Predict: 1.0
Conclusions:
- Increasing the tree depth generally yields a more accurate model, but the deeper the tree, the more complex the model and the more severely it overfits the training data.
- The difference between the two impurity measures has only a limited effect on performance.
(The conclusions above draw on the book Machine Learning with Spark (《Spark机器学习》), p. 113.)
This article drew on the following blog posts:
(1)http://blog.csdn.net/gumpeng/article/details/51397737
A Python version with very detailed comments and the most accessible examples.
(2) ID3:
http://blog.csdn.net/acdreamers/article/details/44661149
(3) Decision trees in Scikit:
http://blog.csdn.net/sandyzhs/article/details/46814805
(4) The difference between Gini impurity and entropy:
http://blog.csdn.net/lingtianyulong/article/details/34522757
(5) 算法杂货铺 — classification algorithms: decision trees (Decision tree):
http://www.cnblogs.com/leoo2sk/archive/2010/09/19/decision-tree.html
(6) Decision tree pruning:
http://blog.sina.com.cn/s/blog_4e4dec6c0101fdz6.html
(7) OpenCV 2.4.9 source code analysis — Decision Trees:
http://blog.csdn.net/zhaocj/article/details/50503450
Contains very detailed explanations of both the code and the theory.
Other algorithms
Post-pruning can be computed in several ways; here is a relatively simple algorithm:
Pseudocode:
Split the test dataset according to the existing tree:
If either resulting subset corresponds to a subtree, run the pruning procedure recursively on that subset
Compute the error after merging the current pair of leaf nodes (1)
Compute the error without merging
If merging would lower the error, merge the leaf nodes
The error referred to in (1) is computed by taking, for each record, the squared difference between its value and the mean, and summing these values, i.e. the total squared error. This is one way of expressing disorder: the more disordered the data, the larger the value should be. A drop in the error therefore means the node has become "purer", which is why "if merging would lower the error, merge the leaf nodes".
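A minimal Java sketch of this pseudocode for a regression tree follows. The TreeNode class and the helper methods are hypothetical illustrations assumed for this sketch; they are not part of Spark MLlib or any other library. The last element of each double[] row is assumed to be the target value, and the merged prediction is taken as the average of the two leaf predictions.
import java.util.ArrayList;
import java.util.List;
class TreeNode {
    int splitFeature=-1;   // index of the feature used to split; -1 for a leaf
    double splitValue;     // split threshold
    double prediction;     // leaf value (mean of the targets that reached this leaf)
    TreeNode left, right;
    boolean isLeaf() { return left==null && right==null; }
}
public class SimplePostPruning {
    // Total squared error of predicting a single constant for all rows
    static double squaredError(List<double[]> rows, double prediction) {
        double err=0.0;
        for (double[] row : rows) {
            double diff=row[row.length-1]-prediction; // last column = target value
            err+=diff*diff;
        }
        return err;
    }
    // Route the test rows through the node's split test
    static List<List<double[]>> splitData(List<double[]> rows, TreeNode node) {
        List<double[]> leftRows=new ArrayList<>(), rightRows=new ArrayList<>();
        for (double[] row : rows) {
            if (row[node.splitFeature]<=node.splitValue) leftRows.add(row);
            else rightRows.add(row);
        }
        List<List<double[]>> parts=new ArrayList<>();
        parts.add(leftRows); parts.add(rightRows);
        return parts;
    }
    // Recursively prune bottom-up, using a held-out test set
    static TreeNode prune(TreeNode node, List<double[]> rows) {
        if (node.isLeaf() || rows.isEmpty()) return node;
        List<List<double[]>> parts=splitData(rows, node); // split the test data with the existing tree
        node.left=prune(node.left, parts.get(0));         // recurse into any subtree
        node.right=prune(node.right, parts.get(1));
        if (node.left.isLeaf() && node.right.isLeaf()) {
            // error without merging: each leaf keeps its own prediction
            double errorNoMerge=squaredError(parts.get(0), node.left.prediction)
                    +squaredError(parts.get(1), node.right.prediction);
            // error after merging the two leaves into one
            double mergedPrediction=(node.left.prediction+node.right.prediction)/2.0;
            double errorMerge=squaredError(rows, mergedPrediction);
            if (errorMerge<errorNoMerge) {  // merging lowers the error: collapse to a leaf
                node.left=null;
                node.right=null;
                node.prediction=mergedPrediction;
            }
        }
        return node;
    }
}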