CART决策树----基尼指数划分
Posted 独行的喵
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了CART决策树----基尼指数划分相关的知识,希望对你有一定的参考价值。
文章目录
CART决策树----基尼指数划分
一.决策树算法的构建
一般的,一棵决策树包含一个根节点,若干个内部结点和若干个叶结点;叶结点对应于决策结果,其他每个结点则对应于一个属性测试;每个结点包含的样本集合根据测试属性的结果被划分到子结点中;根节点包含样本全集。从根节点到每个叶节点的路径对应了一个判定测试序列。决策树学习的目的是为了产生一棵泛化能力强,即处理未见例能力强的决策树,其基本流程遵循简单而直观的分而治之策略
————from 西瓜书
决策树算法伪代码:
决策树的生成是一个递归的过程,在决策树生成算法的过程中,有三种情形需要递归返回:
- (1)当前结点包含的样本全部属于同一类别,无需划分,例如当前结点全是正例(好瓜):设置为叶子结点,返回当前集合的类别
- (2)当前属性集为空,或是所有样本在所有属性上取值相同,无法划分,例如对所有类别属性的划分都已经结束,无法进一步划分,或者比如色泽青绿的瓜其他属性都一样,也不必继续划分。:把当前结点标记为叶子结点,并将其类别设置为该节点所含样本最多的类别。
- (3)当前结点包含的样本集合为空,不能划分:当前结点设置为叶子节点,单将其类别设置为父节点所含样本中最多的类别。
二.划分选择——基尼指数
在决策树的建立过程当中,涉及到很多对当前结点集合的划分操作,而如何选择最优划分属性是决策树算法的关键问题之一。
一般而言,随着划分过程的不断进行,我们希望决策树的分支节点所包含的样本尽可能属于同一类别,即结点的 纯度(purity) 越来越高。
基尼指数:CART决策树使用基尼指数(Gini index)来选择划分属性:
G
i
n
i
(
D
)
=
1
−
∑
k
=
1
∣
y
∣
p
k
2
Gini(D)=1-\\sum_k=1^|y|p_k^2
Gini(D)=1−k=1∑∣y∣pk2
直观来说,Gini(D)反映了从数据集D中随机抽取两个样本,其类别标记不一致的概率,因此Gini(D)越小,则数据集D的纯度越高。
属性a的基尼指数定义为:
G
i
n
i
_
i
n
d
e
x
(
D
,
a
)
=
∑
v
=
1
V
∣
D
v
∣
∣
D
∣
G
i
n
i
(
D
v
)
Gini\\_index(D,a)=\\sum_v=1^V\\frac|D^v||D|Gini(D^v)
Gini_index(D,a)=v=1∑V∣D∣∣Dv∣Gini(Dv)
于是在候选属性集A中,选择哪个使得划分后基尼指数最小的属性作为最优划分属性,即:
a
∗
=
a
r
g
m
i
n
G
i
n
i
_
i
n
d
e
x
(
D
,
a
)
a*=arg\\space min \\space Gini\\_index(D,a)
a∗=arg min Gini_index(D,a)
西瓜数据集如下:
编号 | 色泽 | 根蒂 | 敲声 | 纹理 | 脐部 | 触感 | 好瓜 |
---|---|---|---|---|---|---|---|
0 | 青绿 | 蜷缩 | 浊响 | 清晰 | 凹陷 | 硬滑 | 是 |
1 | 乌黑 | 蜷缩 | 沉闷 | 清晰 | 凹陷 | 硬滑 | 是 |
2 | 乌黑 | 蜷缩 | 浊响 | 清晰 | 凹陷 | 硬滑 | 是 |
3 | 青绿 | 蜷缩 | 沉闷 | 清晰 | 凹陷 | 硬滑 | 是 |
4 | 浅白 | 蜷缩 | 浊响 | 清晰 | 凹陷 | 硬滑 | 是 |
5 | 青绿 | 稍蜷 | 浊响 | 清晰 | 稍凹 | 软粘 | 是 |
6 | 乌黑 | 稍蜷 | 浊响 | 稍糊 | 稍凹 | 软粘 | 是 |
7 | 乌黑 | 稍蜷 | 浊响 | 清晰 | 稍凹 | 硬滑 | 否 |
8 | 乌黑 | 稍蜷 | 沉闷 | 稍糊 | 稍凹 | 硬滑 | 否 |
9 | 青绿 | 硬挺 | 清脆 | 清晰 | 平坦 | 软粘 | 否 |
10 | 浅白 | 硬挺 | 清脆 | 模糊 | 平坦 | 硬滑 | 否 |
11 | 浅白 | 蜷缩 | 浊响 | 模糊 | 平坦 | 软粘 | 否 |
12 | 青绿 | 稍蜷 | 浊响 | 稍糊 | 凹陷 | 硬滑 | 否 |
13 | 浅白 | 稍蜷 | 沉闷 | 稍糊 | 凹陷 | 硬滑 | 否 |
14 | 乌黑 | 稍蜷 | 浊响 | 清晰 | 稍凹 | 软粘 | 否 |
15 | 浅白 | 蜷缩 | 浊响 | 模糊 | 平坦 | 硬滑 | 否 |
16 | 青绿 | 蜷缩 | 沉闷 | 稍糊 | 稍凹 | 硬滑 | 否 |
我们来模拟一下第一次根据基尼指数选择最后划分属性的过程:
假如我们选择的训练集为以下编号的数据:
[0, 1, 2, 3, 5, 6, 9, 13, 14, 15, 16]
我们以对色泽的基尼指数计算为例:
色泽属性中对应的特征有:青绿,乌黑,浅白
在训练集中:
- 青绿的个数为5个,其中是好瓜的有3个
- 乌黑的个数为4个,其中是好瓜的有4个
- 浅白的个数为2个,其中是好瓜的有0个
据此,由基尼指数公式,我们可以计算D中的Gini(D,色泽):
( 1 − ( 3 5 ) 2 − ( 2 5 ) 2 ) ∗ 5 11 + ( 1 − ( 3 4 ) 2 − ( 1 4 ) 2 ) ∗ 4 11 = 0.35454545... (1-(\\frac35)^2-(\\frac25)^2)*\\frac511+(1-(\\frac34)^2-(\\frac14)^2)*\\frac411=0.35454545... (1−(53)2−(52)2)∗115+(1−(43)2−(41)2)∗114=0.35454545...
将训练集的各项属性的基尼指数计算得出后:
最终数值最小的“脐部”作为最优划分属性。
根据“脐部”属性特征的不同,按照脐部为:凹陷,稍凹,平坦,将数据集分为三个子集,也就是构建出决策树的三个子节点。再以每一个子节点为数据集,在排除脐部以外的属性集中,选择出下一个最优划分属性来进行进一步的划分或由递归返回条件变为叶子节点并得出分类标记。依次类推,最终将在递归的划分中创建出整颗决策树。其中将数据集分类再处理再分类的过程体现了分而治之的思想。
三.剪枝处理
剪枝(pruning)是决策树算法对付”过拟合“的主要手段
1.预剪枝
预剪枝是指在决策树生成过程中,对每个结点在划分前先进行估计,若当前结点的划分不能带来决策树泛化性能的提升,则停止划分并将当前结点标记为叶节点,其类型标记为当前结点数据集中总数最多的类别。
预剪枝可以时决策树很多分支不进行展开,降低了过拟合的风险,同时还显著减少了决策树的训练时间开销和测试时间开销。但另一方面,有些分支的当前划分虽然不能提升泛化性能,甚至可能导致泛化性能下降,但在其基础上进行的后续划分却有可能导致泛化性能显著提升。预剪枝基于贪心本质禁止这些分支展开,给预剪枝决策树带来了欠拟合的风险。
2.后剪枝
后剪枝是先从训练集生成一棵完整的决策树,然后自底向上地对非叶节点进行考察,若将该节点对应的子树替换成叶节点能带来决策树泛化性能的提升,则将该子树替换为叶节点。
后剪枝决策树通常比预剪枝决策树保留了更多的分支。一般情形下,后剪枝决策树的欠拟合风险很小,泛化性能往往优于预剪枝决策树。但是后剪枝过程是在生成完全决策树之后进行的,并且自底向上地对树中的所有非叶节点进行逐一考察,因此其训练时间开销比未剪枝决策树和预剪枝决策树都要大得多,也就算法的时间复杂度往往比较大。
四.算法代码
算法代码参考了博文:https://blog.csdn.net/m0_37822685/article/details/100055766
import numpy as np
import matplotlib.pyplot as plt
from pylab import *
import operator
# 特征字典,后面用到了好多次,干脆当全局变量了
featureDic =
'色泽': ['浅白', '青绿', '乌黑'],
'根蒂': ['硬挺', '蜷缩', '稍蜷'],
'敲声': ['沉闷', '浊响', '清脆'],
'纹理': ['清晰', '模糊', '稍糊'],
'脐部': ['凹陷', '平坦', '稍凹'],
'触感': ['硬滑', '软粘']
# ***********************画图***********************
# **********************start***********************
# 详情参见机器学习实战决策树那一章
# 定义文本框和箭头格式
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")
mpl.rcParams['font.sans-serif'] = ['SimHei'] # 没有这句话汉字都是口口
# mpl.rcParams['axes.unicode_minus'] = False # 解决保存图像是负号'-'显示为方块的问题
def plotMidText(cntrPt, parentPt, txtString):
xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
createPlot.ax1.text(xMid, yMid, txtString, fontsize=20)
def plotNode(nodeTxt, centerPt, parentPt, nodeType): # 绘制带箭头的注解
createPlot.ax1.annotate(nodeTxt,
xy=parentPt,
xycoords="axes fraction",
xytext=centerPt,
textcoords="axes fraction",
va="center",
ha="center",
bbox=nodeType,
arrowprops=arrow_args,
fontsize=20)
def getNumLeafs(myTree): # 获取叶节点的数目
numLeafs = 0
firstStr = list(myTree.keys())[0]
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':
numLeafs += getNumLeafs(secondDict[key])
else:
numLeafs += 1
return numLeafs
def getTreeDepth(myTree): # 获取树的层数
maxDepth = 0
firstStr = list(myTree.keys())[0]
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':
thisDepth = 1 + getTreeDepth(secondDict[key])
else:
thisDepth = 1
if thisDepth > maxDepth: maxDepth = thisDepth
return maxDepth
def plotTree(myTree, parentPt, nodeTxt):
numLeafs = getNumLeafs(myTree)
getTreeDepth(myTree)
firstStr = list(myTree.keys())[0]
cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
plotTree.yOff)
plotMidText(cntrPt, parentPt, nodeTxt)
plotNode(firstStr, cntrPt, parentPt, decisionNode)
secondDict = myTree[firstStr]
plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':
plotTree(secondDict[key], cntrPt, str(key))
else:
plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff),
cntrPt, leafNode)
plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD
def createPlot(inTree):
fig = plt.figure(1, figsize=(600, 30), facecolor='white')
fig.clf()
axprops = dict(xticks=[], yticks=[])
createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
plotTree.totalW = float(getNumLeafs(inTree))
plotTree.totalD = float(getTreeDepth(inTree))
plotTree.xOff = -0.5 / plotTree.totalW
plotTree.yOff = 1.0
plotTree(inTree, (0.5, 1.0), '')
plt.show()
# ***********************画图***********************
# ***********************end************************
def getDataSet():
"""
get watermelon data set 3.0 alpha.
:return: 训练集合剪枝集以及特征列表。
"""
# 也可以直接从
dataSet = [
['青绿', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', '好瓜'],
['乌黑', '蜷缩', '沉闷', '清晰', '凹陷', '硬滑', '好瓜'],
['乌黑', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', '好瓜'],
['青绿', '蜷缩', '沉闷', '清晰', '凹陷', '硬滑', '好瓜'],
['浅白', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', '好瓜'],
['青绿', '稍蜷', '浊响', '清晰', '稍凹', '软粘', '好瓜'],
['乌黑', '稍蜷', '浊响', '稍糊', '稍凹', '软粘', '好瓜'],
['乌黑', '稍蜷', '浊响', '清晰', '稍凹', '硬滑', '好瓜'],
['乌黑', '稍蜷', '沉闷', '稍糊', '稍凹', '硬滑', '坏瓜'],
['青绿', '硬挺', '清脆', '清晰', '平坦', '软粘', '坏瓜'],
['浅白', '硬挺', '清脆', '模糊', '平坦', '硬滑', '坏瓜'],
数据挖掘领域经典算法——CART算法