决策树代码《机器学习实战》

Posted 自嗨锅

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了决策树代码《机器学习实战》相关的知识,希望对你有一定的参考价值。

22:45:17 2017-08-09

KNN算法简单有效,可以解决很多分类问题。但是无法给出数据的含义,就是一顿计算向量距离,然后分类。

决策树就可以解决这个问题,分类之后能够知道是问什么被划分到一个类。用图形画出来就效果更好了,这次没有学哪个画图的,下次。

这里只涉及信息熵的计算,最佳分类特征的提取,决策树的构建。剪枝没有学,这里没有。

  1 # -*- oding: itf-8 -*-
  2 
  3 ‘‘‘
  4 function: 《机器学习实战》决策树的代码,画图的部分没有写;
  5 note: 贴出来以后用方便一点~
  6 date: 2017.8.9
  7 ‘‘‘
  8 
  9 from numpy import *
 10 from math import log
 11 import operator
 12 
 13 #计算香浓信息熵
 14 def calcuEntropy(dataSet):
 15     numOfEntries = len(dataSet)
 16     featVec = {}
 17     for data in dataSet:
 18         currentLabel = data[-1]
 19         if currentLabel not in featVec.keys():
 20             featVec[currentLabel] = 1
 21         else:
 22             featVec[currentLabel] += 1
 23     shannonEntropy = 0.0
 24     for feat in featVec.keys():
 25         prob = float(featVec[feat]) / numOfEntries
 26         shannonEntropy += -prob*log(prob, 2) 
 27     return shannonEntropy
 28 
 29 #产生数据集
 30 def loadDataSet():
 31     dataSet = [[1,1,yes],
 32                 [1,0,no],
 33                 [0,1,no],
 34                 [0,1,no]]
 35     labels = [no surfacing, flippers]
 36     return dataSet, labels
 37 
 38 ‘‘‘
 39 function: split the dataset
 40 return: 基于划分特征划分之后我们想要的那部分集合
 41 parameters: dataSet: 数据集,axis: 要划分的特征, value:要返回的集合的axis特征值
 42 ‘‘‘
 43 def splitDataSet(dataSet, axis, value):
 44     retDataSet = [] #防止原始的数据集被修改
 45     for featVec in dataSet:
 46         if featVec[axis] == value: #我们想要的数值存起来,一会返回
 47             reducedFeatVec = featVec[:axis]
 48             reducedFeatVec.extend(featVec[axis+1:])
 49             retDataSet.append(reducedFeatVec)
 50     return retDataSet
 51 
 52 ‘‘‘
 53 function: 找出数据集中最佳的划分特征
 54 ‘‘‘
 55 def chooseBestClassifyFeat(dataSet):
 56     numOfFeatures = len(dataSet[0]) - 1
 57     bestFeature = -1  #初始化最佳的划分特征
 58     baseInfoGain = 0.0 #信息增益
 59     baseEntropy = calcuEntropy(dataSet)
 60     for i in range(numOfFeatures):
 61         # if numOfFeatures == 1: #错了,只有一个特征不是只有一个类别
 62         #     print(‘only one feature‘)
 63         #     print(dataSet[0][0])
 64         #     return dataSet[0][0] #只有一个特征直接返回该特征
 65         featList = [example[i] for example in dataSet] #或者第i个特征所有的取值
 66         unicVals = set(featList) #不重复的第i个特征取值
 67         newEntropy = 0.0
 68         for value in unicVals:
 69             subDataSet = splitDataSet(dataSet, i, value)
 70 
 71             #计算划分之后各个子数据集的信息熵,然后累加就是这个划分的信息熵
 72             currentEntropy = calcuEntropy(subDataSet) 
 73             prob = float(len(subDataSet)) / len(dataSet)
 74             newEntropy += prob * currentEntropy
 75         newInfoGain = baseEntropy - newEntropy
 76         if newInfoGain > baseInfoGain:
 77             bestFeature = i
 78             baseInfoGain = newInfoGain
 79     return bestFeature 
 80 
 81 ‘‘‘
 82 function: 多数表决,当分类器用完所有属性,叶节点还是类别不统一的时候调用这个函数
 83 arg: labelList 类别标签列表
 84 ‘‘‘
 85 def majorityCount(labelList):
 86     classCount = {}
 87     for label in labelList:
 88         if label not in classCount.keys():
 89             classCount[label] = 0
 90         classCount[label] += 1
 91     sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1),reverse = True)
 92     print(sortedClassCount)
 93     return sortedClassCount[0][0]
 94 
 95 
 96 ‘‘‘
 97 function: 递归的建造决策树
 98 arg: dataset: 数据集 labels: 代表特征的标签,起始算法不需要,比如fippers代表第一个特征的意义
 99 ‘‘‘
100 def createTree(dataSet, labels):
101     classList = [example[-1] for example in dataSet] #得到所有的类别
102     if classList.count(classList[0]) == len(classList): #只有一种类别,直接返回
103         return classList[0]
104     if len(dataSet[0]) == 1: #特征属性用完了但是还没有完全分开,多数表决
105         return majorityCount(classList)
106     bestFeat = chooseBestClassifyFeat(dataSet)
107     print(bestFeat =  + str(bestFeat))
108     bestFeatLabel = labels[bestFeat]
109     del(labels[bestFeat]) #删除这次使用的特征
110     featValues = [example[bestFeat] for example in dataSet]
111     myTree = {bestFeatLabel: {}}
112     unicVals = set(featValues)
113     for value in unicVals:
114         labelCopy = labels[:]
115         subDataSet = splitDataSet(dataSet, bestFeat, value)
116         myTree[bestFeatLabel][value] = createTree(subDataSet, labelCopy)
117     return myTree
118 
119 ‘‘‘
120 function: 用决策树进行分类
121 arg: inputTree: 训练好的决策树,featLabels: 特征标签,testVec: 待分类的向量
122 ‘‘‘
123 def classify(inputTree, featLabel, testVec):
124     firstStr = list(inputTree.keys())[0] #python3 dict,.keys()不支持索引,必须转换一下
125     secondDict = inputTree[firstStr] #second tree
126     featIndex = featLabel.index(firstStr) #可利用index函数找到这个特征标签对饮过的特征位置
127     for key in secondDict.keys():
128         if testVec[featIndex] == key:
129             if type(secondDict[key]).__name__ == dict: #说明下面不是叶子节点,继续分类
130                 classLabel = classify(secondDict[key], featLabel, testVec)
131             else:
132                 classLabel = secondDict[key] #到达叶子节点,直接返回类别标签
133     return classLabel
134 
135 ‘‘‘
136 function: 使用pickle模块持久化存储决策树
137 note:
138 ‘‘‘
139 def storeTree(inputTree, filename):
140     import pickle
141     fw = open(filename, wb)
142     pickle.dump(inputTree, fw)
143     fw.close()
144 
145 ‘‘‘
146 function: 从本地文件中读取决策树
147 ‘‘‘
148 def grabTree(filename):
149     import pickle
150     fr = open(filename,rb)
151     return pickle.load(fr)
152 
153 #测试信息熵的计算
154 dataSet, labels = loadDataSet()
155 shannon = calcuEntropy(dataSet)
156 print(shannon)
157 
158 #测试数据集分割
159 print(dataSet)
160 retDataSet = splitDataSet(dataSet, 1, 1)
161 print(retDataSet)
162 retDataSet = splitDataSet(dataSet, 1, 0)
163 print(retDataSet)
164 
165 #寻找最佳的划分特征
166 bestFeature = chooseBestClassifyFeat(dataSet)
167 print(bestFeature)
168 
169 #测试多数表决
170 out = majorityCount([1,1,2,2,2,1,2,2])
171 print(out)
172 
173 #创建决策大叔
174 myTree = createTree(dataSet, labels)
175 print(myTree)
176 
177 #测试分类器
178 dataSet, labels = loadDataSet()
179 classLabel = classify(myTree, labels, [0,1])
180 print(classLabel)
181 classLabel = classify(myTree, labels, [1,1])
182 print(classLabel)
183 
184 #持久化存储决策树
185 storeTree(myTree, classifierStorage.txt)
186 outTree = grabTree(classifierStorage.txt)
187 print(outTree)

 

以上是关于决策树代码《机器学习实战》的主要内容,如果未能解决你的问题,请参考以下文章

机器学习实战笔记--决策树

机器学习实战--第三章决策树

机器学习-决策树

机器学习实战笔记(Python实现)-02-决策树

机器学习实战教程:决策树实战篇

决策树应用