cart树回归及其剪枝的python实现
Posted
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了cart树回归及其剪枝的python实现相关的知识,希望对你有一定的参考价值。
前言
前文讨论的回归算法都是全局且针对线性问题的回归,即使是其中的局部加权线性回归法,也有其弊端(具体请参考前文)
采用全局模型会导致模型非常的臃肿,因为需要计算所有的样本点,而且现实生活中很多样本都有大量的特征信息。
另一方面,实际生活中更多的问题都是非线性问题。
针对这些问题,有了树回归系列算法。
回归树
在先前决策树的学习中,构建树是采用的 ID3 算法。在回归领域,该算法就有个问题,就是派生子树是按照所有可能值来进行派生。
因此 ID3 算法无法处理连续性数据。
故可使用二元切分法,以某个特定值为界进行切分。在这种切分法下,子树个数小于等于2。
除此之外,再修改择优原则香农熵 (因为数据变为连续型的了),便可将树构建成一棵可用于回归的树,这样一棵树便叫做回归树。
构建回归树的伪代码:
1 找到最佳的待切分特征: 2 如果该节点不能再分,将此节点存为叶节点。 3 执行二元切分 4 左右子树分别递归调用此函数
二元切分的伪代码:
1 对每个特征: 2 对每个特征值: 3 将数据集切成两份 4 计算切分误差 5 如果当前误差小于最小误差,则更新最佳切分以及最小误差。
特别说明,终止划分 (并直接建立叶节点)有三种情况:
1. 特征值划分完毕
2. 划分子集太小
3. 划分后误差改进不大
这几个操作被称做 "预剪枝"。
下面给出一个完整的回归树的小程序:
1 #!/usr/bin/env python 2 # -*- coding:UTF-8 -*- 3 4 ‘‘‘ 5 Created on 20**-**-** 6 7 @author: fangmeng 8 ‘‘‘ 9 10 from numpy import * 11 12 def loadDataSet(fileName): 13 ‘载入测试数据‘ 14 15 dataMat = [] 16 fr = open(fileName) 17 for line in fr.readlines(): 18 curLine = line.strip().split(‘\\t‘) 19 # 所有元素转换为浮点类型(函数编程) 20 fltLine = map(float,curLine) 21 dataMat.append(fltLine) 22 return dataMat 23 24 #============================ 25 # 输入: 26 # dataSet: 待切分数据集 27 # feature: 切分特征序号 28 # value: 切分值 29 # 输出: 30 # mat0,mat1: 切分结果 31 #============================ 32 def binSplitDataSet(dataSet, feature, value): 33 ‘切分数据集‘ 34 35 mat0 = dataSet[nonzero(dataSet[:,feature] > value)[0],:][0] 36 mat1 = dataSet[nonzero(dataSet[:,feature] <= value)[0],:][0] 37 return mat0,mat1 38 39 #======================================== 40 # 输入: 41 # dataSet: 数据集 42 # 输出: 43 # mean(dataSet[:,-1]): 均值(也就是叶节点的内容) 44 #======================================== 45 def regLeaf(dataSet): 46 ‘生成叶节点‘ 47 48 return mean(dataSet[:,-1]) 49 50 #======================================== 51 # 输入: 52 # dataSet: 数据集 53 # 输出: 54 # var(dataSet[:,-1]) * shape(dataSet)[0]: 平方误差 55 #======================================== 56 def regErr(dataSet): 57 ‘计算平方误差‘ 58 59 return var(dataSet[:,-1]) * shape(dataSet)[0] 60 61 #======================================== 62 # 输入: 63 # dataSet: 数据集 64 # leafType: 叶子节点生成器 65 # errType: 误差统计器 66 # ops: 相关参数 67 # 输出: 68 # bestIndex: 最佳划分特征 69 # bestValue: 最佳划分特征值 70 #======================================== 71 def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)): 72 ‘选择最优划分‘ 73 74 # 获得相关参数中的最大样本数和最小误差效果提升值 75 tolS = ops[0]; 76 tolN = ops[1] 77 78 # 如果所有样本点的值一致,那么直接建立叶子节点。 79 if len(set(dataSet[:,-1].T.tolist()[0])) == 1: 80 return None, leafType(dataSet) 81 82 m,n = shape(dataSet) 83 # 当前误差 84 S = errType(dataSet) 85 # 最小误差 86 bestS = inf; 87 # 最小误差对应的划分方式 88 bestIndex = 0; 89 bestValue = 0 90 91 # 对于所有特征 92 for featIndex in range(n-1): 93 # 对于某个特征的所有特征值 94 for splitVal in set(dataSet[:,featIndex]): 95 # 划分 96 mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal) 97 # 如果划分后某个子集中的个数不达标 98 if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): continue 99 # 当前划分方式的误差 100 newS = errType(mat0) + errType(mat1) 101 # 如果这种划分方式的误差小于最小误差 102 if newS < bestS: 103 bestIndex = featIndex 104 bestValue = splitVal 105 bestS = newS 106 107 # 如果当前划分方式还不如不划分时候的误差效果 108 if (S - bestS) < tolS: 109 return None, leafType(dataSet) 110 # 按照最优划分方式进行划分 111 mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue) 112 # 如果划分后某个子集中的个数不达标 113 if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): 114 return None, leafType(dataSet) 115 116 return bestIndex,bestValue 117 118 #======================================== 119 # 输入: 120 # dataSet: 数据集 121 # leafType: 叶子节点生成器 122 # errType: 误差统计器 123 # ops: 相关参数 124 # 输出: 125 # retTree: 回归树 126 #======================================== 127 def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)): 128 ‘构建回归树‘ 129 130 # 选择最佳划分方式 131 feat, val = chooseBestSplit(dataSet, leafType, errType, ops) 132 # feat为None的时候无需划分返回叶子节点 133 if feat == None: return val #if the splitting hit a stop condition return val 134 135 # 递归调用构建函数并更新树 136 retTree = {} 137 retTree[‘spInd‘] = feat 138 retTree[‘spVal‘] = val 139 lSet, rSet = binSplitDataSet(dataSet, feat, val) 140 retTree[‘left‘] = createTree(lSet, leafType, errType, ops) 141 retTree[‘right‘] = createTree(rSet, leafType, errType, ops) 142 143 return retTree 144 145 def test(): 146 ‘展示结果‘ 147 148 # 载入数据 149 myDat = loadDataSet(‘/home/fangmeng/ex0.txt‘) 150 # 构建回归树 151 myDat = mat(myDat) 152 153 print createTree(myDat) 154 155 156 if __name__ == ‘__main__‘: 157 test()
测试结果:
回归树的优化工作 - 剪枝
在上面的代码中,终止递归的条件中已经加入了重重的 "剪枝" 工作。
这些在建树的时候的剪枝操作通常被成为预剪枝。这是很有很有必要的,经过预剪枝的树几乎就是没有预剪枝树的大小的百分之一甚至更小,而性能相差无几。
而在树建立完毕之后,基于训练集和测试集能做更多更高效的剪枝工作,这些工作叫做 "后剪枝"。
可见,剪枝是一项较大的工作量,是对树非常关键的优化过程。
后剪枝过程的伪代码如下:
1 基于已有的树切分测试数据: 2 如果存在任一子集是一棵树,则在该子集上递归该过程。 3 计算将当前两个叶节点合并后的误差 4 计算不合并的误差 5 如果合并会降低误差,则将叶节点合并。
具体实现函数如下:
1 #=================================== 2 # 输入: 3 # obj: 判断对象 4 # 输出: 5 # (type(obj).__name__==‘dict‘): 判断结果 6 #=================================== 7 def isTree(obj): 8 ‘判断对象是否为树类型‘ 9 10 return (type(obj).__name__==‘dict‘) 11 12 #=================================== 13 # 输入: 14 # tree: 处理对象 15 # 输出: 16 # (tree[‘left‘]+tree[‘right‘])/2.0: 坍塌后的替代值 17 #=================================== 18 def getMean(tree): 19 ‘坍塌处理‘ 20 21 if isTree(tree[‘right‘]): tree[‘right‘] = getMean(tree[‘right‘]) 22 if isTree(tree[‘left‘]): tree[‘left‘] = getMean(tree[‘left‘]) 23 24 return (tree[‘left‘]+tree[‘right‘])/2.0 25 26 #=================================== 27 # 输入: 28 # tree: 处理对象 29 # testData: 测试数据集 30 # 输出: 31 # tree: 剪枝后的树 32 #=================================== 33 def prune(tree, testData): 34 ‘后剪枝‘ 35 36 # 无测试数据则坍塌此树 37 if shape(testData)[0] == 0: 38 return getMean(tree) 39 40 # 若左/右子集为树类型 41 if (isTree(tree[‘right‘]) or isTree(tree[‘left‘])): 42 # 划分测试集 43 lSet, rSet = binSplitDataSet(testData, tree[‘spInd‘], tree[‘spVal‘]) 44 # 在新树新测试集上递归进行剪枝 45 if isTree(tree[‘left‘]): tree[‘left‘] = prune(tree[‘left‘], lSet) 46 if isTree(tree[‘right‘]): tree[‘right‘] = prune(tree[‘right‘], rSet) 47 48 # 如果两个子集都是叶子的话,则在进行误差评估后决定是否进行合并。 49 if not isTree(tree[‘left‘]) and not isTree(tree[‘right‘]): 50 lSet, rSet = binSplitDataSet(testData, tree[‘spInd‘], tree[‘spVal‘]) 51 errorNoMerge = sum(power(lSet[:,-1] - tree[‘left‘],2)) +sum(power(rSet[:,-1] - tree[‘right‘],2)) 52 treeMean = (tree[‘left‘]+tree[‘right‘])/2.0 53 errorMerge = sum(power(testData[:,-1] - treeMean,2)) 54 if errorMerge < errorNoMerge: 55 return treeMean 56 else: return tree 57 else: return tree
模型树
这也是一种很棒的树回归算法。
该算法将所有的叶子节点不是表述成一个值,而是对叶子部分节点建立线性模型。比如可以是最小二乘法的基本线性回归模型。
这样在叶子节点里存放的就是一组线性回归系数了。非叶子节点部分构造就和回归树一样。
这个是上面建立回归树算法的函数头:
createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
对于模型树,只需要修改修改 leafType(叶节点构造器) 和 errType(误差分析器) 的实现即可,分别对应如下modelLeaf 函数和 modelErr 函数:
1 #========================= 2 # 输入: 3 # dataSet: 测试集 4 # 输出: 5 # ws,X,Y: 回归模型 6 #========================= 7 def linearSolve(dataSet): 8 ‘辅助函数,用于构建线性回归模型。‘ 9 10 m,n = shape(dataSet) 11 X = mat(ones((m,n))); 12 Y = mat(ones((m,1))) 13 X[:,1:n] = dataSet[:,0:n-1]; 14 Y = dataSet[:,-1] 15 xTx = X.T*X 16 if linalg.det(xTx) == 0.0: 17 raise NameError(‘系数矩阵不可逆‘) 18 ws = xTx.I * (X.T * Y) 19 return ws,X,Y 20 21 #======================= 22 # 输入: 23 # dataSet: 数据集 24 # 输出: 25 # ws: 回归系数 26 #======================= 27 def modelLeaf(dataSet): 28 ‘叶节点构造器‘ 29 30 ws,X,Y = linearSolve(dataSet) 31 return ws 32 33 #======================================= 34 # 输入: 35 # dataSet: 数据集 36 # 输出: 37 # sum(power(Y - yHat,2)): 平方误差 38 #======================================= 39 def modelErr(dataSet): 40 ‘误差分析器‘ 41 42 ws,X,Y = linearSolve(dataSet) 43 yHat = X * ws 44 return sum(power(Y - yHat,2))
回归树 / 模型树的使用
前面的工作主要介绍了两种树 - 回归树,模型树的构建,下面进一步学习如何利用这些树来进行预测。
当然,本质也就是递归遍历树。
下为遍历代码,通过修改参数设置要使用并传递进来的是回归树还是模型树:
1 #============================== 2 # 输入: 3 # model: 叶子 4 # inDat: 测试数据 5 # 输出: 6 # float(model): 叶子值 7 #============================== 8 def regTreeEval(model, inDat): 9 ‘回归树预测‘ 10 11 return float(model) 12 13 #============================== 14 # 输入: 15 # model: 叶子 16 # inDat: 测试数据 17 # 输出: 18 # float(X*model): 叶子值 19 #============================== 20 def modelTreeEval(model, inDat): 21 ‘模型树预测‘ 22 n = shape(inDat)[1] 23 X = mat(ones((1,n+1))) 24 X[:,1:n+1]=inDat 25 return float(X*model) 26 27 #============================== 28 # 输入: 29 # tree: 待遍历树 30 # inDat: 测试数据 31 # modelEval: 叶子值获取器 32 # 输出: 33 # 分类结果 34 #============================== 35 def treeForeCast(tree, inData, modelEval=regTreeEval): 36 ‘使用回归/模型树进行预测 (modelEval参数指定)‘ 37 38 # 如果非树类型,返回值。 39 if not isTree(tree): return modelEval(tree, inData) 40 41 # 左遍历 42 if inData[tree[‘spInd‘]] > tree[‘spVal‘]: 43 if isTree(tree[‘left‘]): return treeForeCast(tree[‘left‘], inData, modelEval) 44 else: return modelEval(tree[‘left‘], inData) 45 46 # 右遍历 47 else: 48 if isTree(tree[‘right‘]): return treeForeCast(tree[‘right‘], inData, modelEval) 49 else: return modelEval(tree[‘right‘], inData)
使用方法非常简单,将树和要分类的样本传递进去就可以了。如果是模型树就将分类函数 treeForeCast 的第三个参数改为modelTreeEval即可。
这里就不再演示实验具体过程了。
小结
1. 选择哪个回归方法,得看哪个方法的相关系数高。(可使用 corrcoef 函数计算)
2. 树的回归和分类算法其实本质上都属于贪心算法,不断去寻找局部最优解。
3. 关于回归的讨论就先告一段落,接下来将进入到无监督学习部分。
以上是关于cart树回归及其剪枝的python实现的主要内容,如果未能解决你的问题,请参考以下文章