5-7 CART树的剪枝

Posted windmissing

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了5-7 CART树的剪枝相关的知识,希望对你有一定的参考价值。

CART树的剪枝算法

输入:剪枝前的CART树
输出:剪枝后的CART树

原理

令:
损失函数为:
C a ( T ) = C ( T ) + a ∣ T ∣ (1) C_a(T) = C(T) + a|T| \\tag 1 Ca(T)=C(T)+aT(1)

对树上的所有结点计算:
假如该结点不split,该结点的损失为:
C a ( t ) = C ( t ) + a ∣ T ∣ (2) C_a(t) = C(t) + a|T| \\tag 2 Ca(t)=C(t)+aT(2)
假如该结点split,则该结点的损失为:
C a ( T t ) = C ( T t ) + a ∣ T t ∣ (3) C_a(T_t) = C(T_t) + a|T_t| \\tag 3 Ca(Tt)=C(Tt)+aTt(3)
当a=0或a足够小时,有
C a ( T t ) < C a ( t ) (4) C_a(T_t) < C_a(t) \\tag 4 Ca(Tt)<Ca(t)(4)
即split得到的损失更小。这是肯定的,因为CART树的生成算法就是这样的定义的。因为split得到的损失更小,才会split生成子结点。
但当a慢慢增大,对一个特定的结点t来说,当a取到某个值时,会有
C a ( T t ) = C a ( t ) (5) C_a(T_t) = C_a(t) \\tag 5 Ca(Tt)=Ca(t)(5)
将公式(2)、(3)代码公式(5),得出:
a = C ( t ) − C ( T t ) ∣ T t ∣ − 1 a = \\frac C(t) - C(T_t)|T_t| - 1 a=Tt1C(t)C(Tt)
若对于某个结点来说,此时的a使得KaTeX parse error: \\tag works only in display equations,那么就应该剪枝了。

步骤

  1. 令当前树为 T 0 T_0 T0,表示未做过剪枝的CART树
  2. **自下而上【?】**地对所有结点计算g(t),并同时记录最小的g(t)作为当前的alpha
    g ( t ) = C ( t ) − C ( T t ) ∣ T t ∣ − 1 g(t) = \\frac C(t) - C(T_t)|T_t| - 1 g(t)=Tt1C(t)C(Tt)
  3. 对Tree中 g ( t ) = a g(t)=a g(t)=a的结点做剪枝,得到的新的树 T a T_a Ta
  4. 如果当前的树不只一个根结点,就循环2-3步
  5. 通过交叉验证选出最优的 T a T_a Ta

代码

def isLeaf(Node):
    # return type(Node).__name__ != 'dict'
    return not ('left' in Node.keys() or 'right' in Node.keys())
    
def calcAlphaList(Node):
    if isLeaf(Node):
        return
    costNotSplit = Node['gini']
    costSplit = Node['left']['gini'] + Node['right']['gini']
    alpha = (costNotSplit-costSplit)/(Node['Tt']-1)
    Node['alpha'] = alpha
    if alpha < calcAlphaList.bestAlpha:
        calcAlphaList.bestAlpha = alpha
    calcAlphaList(Node['left'])
    calcAlphaList(Node['right'])
    
def calcTt(Node):
    if isLeaf(Node):
        return 1
    Node['Tt'] = calcTt(Node['left']) + calcTt(Node['right'])
    return Node['Tt']

def cut(Node, alpha):
    if Node['alpha'] == alpha:
        Node.pop('left')
        Node.pop('right')
    else:
        cut(Node['left'], alpha)
        cut(Node['right'], alpha)
        
def prune(Tree):
    i = 0
    print ('i=',i,Tree)
    while not isLeaf(Tree):
        calcTt(Tree)
        calcAlphaList.bestAlpha = np.inf
        calcAlphaList(Tree)
        i += 1
        print ('i=',i,'alpha',calcAlphaList.bestAlpha)
        cut(Tree, calcAlphaList.bestAlpha)
        print (Tree)

这个算法没想通

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-XLrfxPoQ-1586344462475)(http://windmissing.github.io/images_for_gitbook/LiHang-TongJiXueXiFangFa/6.png)]]

以上是关于5-7 CART树的剪枝的主要内容,如果未能解决你的问题,请参考以下文章

5-7 CART树的剪枝

cart树剪枝

决策树的剪枝

机器学习算法:cart剪枝

决策树ID3决策树C4.5决策树CARTCART树的生成树的剪枝从ID3到CART从决策树生成规则决策树优缺点

CART剪枝