决策树
Posted zlj843767688
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了决策树相关的知识,希望对你有一定的参考价值。
from math import log import operator """ 函数说明:计算给定数据集的经验熵(香农熵) Parameters: dataSet:数据集 Returns: shannonEnt:经验熵 Modify: 2018-03-12 """ def calcShannonEnt(dataSet): #返回数据集行数 numEntries=len(dataSet) #保存每个标签(label)出现次数的字典 labelCounts={} #对每组特征向量进行统计 for featVec in dataSet: currentLabel=featVec[-1] #提取标签信息 if currentLabel not in labelCounts.keys(): #如果标签没有放入统计次数的字典,添加进去 labelCounts[currentLabel]=0 labelCounts[currentLabel]+=1 #label计数 shannonEnt=0.0 #经验熵 #计算经验熵 for key in labelCounts: prob=float(labelCounts[key])/numEntries #选择该标签的概率 shannonEnt-=prob*log(prob,2) #利用公式计算 return shannonEnt #返回经验熵 """ 函数说明:按照给定特征划分数据集 Parameters: dataSet:待划分的数据集 axis:划分数据集的特征 value:需要返回的特征值 Returns: 无 Modify: 2018-03-13 """ def splitDataSet(dataSet,axis,value): #创建返回的数据集列表 retDataSet=[] #遍历数据集 for featVec in dataSet: if featVec[axis]==value: #去掉axis特征 reduceFeatVec=featVec[:axis] #将符合条件的添加到返回的数据集 reduceFeatVec.extend(featVec[axis+1:]) retDataSet.append(reduceFeatVec) #返回划分后的数据集 return retDataSet """ 函数说明:计算给定数据集的经验熵(香农熵) Parameters: dataSet:数据集 Returns: shannonEnt:信息增益最大特征的索引值 Modify: 2018-03-13 """ def chooseBestFeatureToSplit(dataSet): #特征数量 numFeatures = len(dataSet[0]) - 1 #计数数据集的香农熵 baseEntropy = calcShannonEnt(dataSet) #信息增益 bestInfoGain = 0.0 #最优特征的索引值 bestFeature = -1 #遍历所有特征 for i in range(numFeatures): # 获取dataSet的第i个所有特征 featList = [example[i] for example in dataSet] #创建set集合{},元素不可重复 uniqueVals = set(featList) #经验条件熵 newEntropy = 0.0 #计算信息增益 for value in uniqueVals: #subDataSet划分后的子集 subDataSet = splitDataSet(dataSet, i, value) #计算子集的概率 prob = len(subDataSet) / float(len(dataSet)) #根据公式计算经验条件熵 newEntropy += prob * calcShannonEnt((subDataSet)) #信息增益 infoGain = baseEntropy - newEntropy #打印每个特征的信息增益 print("第%d个特征的增益为%.3f" % (i, infoGain)) #计算信息增益 if (infoGain > bestInfoGain): #更新信息增益,找到最大的信息增益 bestInfoGain = infoGain #记录信息增益最大的特征的索引值 bestFeature = i #返回信息增益最大特征的索引值 return bestFeature """ 函数说明:统计classList中出现次数最多的元素(类标签) Parameters: classList:类标签列表 Returns: sortedClassCount[0][0]:出现次数最多的元素(类标签) Modify: 2018-03-13 """ def majorityCnt(classList): classCount={} #统计classList中每个元素出现的次数 for vote in classList: if vote not in classCount.keys(): classCount[vote]=0 classCount[vote]+=1 #根据字典的值降序排列 sortedClassCount=sorted(classCount.items(),key=operator.itemgetter(1),reverse=True) return sortedClassCount[0][0] """ 函数说明:创建决策树 Parameters: dataSet:训练数据集 labels:分类属性标签 featLabels:存储选择的最优特征标签 Returns: myTree:决策树 Modify: 2018-03-13 """ def createTree(dataSet,labels,featLabels):#创建树 #取分类标签(是否放贷:yes or no) classList=[example[-1] for example in dataSet] #如果类别完全相同,则停止继续划分 if classList.count(classList[0])==len(classList): return classList[0] #是否没有特征 if len(dataSet[0])==1: # 遍历完所有特征时返回出现次数最多的类标签 return majorityCnt(classList) #选择最优特征 bestFeat=chooseBestFeatureToSplit(dataSet) #最优特征的标签 bestFeatLabel=labels[bestFeat] featLabels.append(bestFeatLabel) #根据最优特征的标签生成树 myTree={bestFeatLabel:{}} #删除已经使用的特征标签 del(labels[bestFeat]) #得到训练集中所有最优特征的属性值 featValues=[example[bestFeat] for example in dataSet] #去掉重复的属性值 uniqueVls=set(featValues) #遍历特征,创建决策树 for value in uniqueVls: myTree[bestFeatLabel][value]=createTree(splitDataSet(dataSet,bestFeat,value), labels,featLabels) return myTree """ 函数说明:创建测试数据集 Parameters:无 Returns: dataSet:数据集 labels:分类属性 Modify: 2018-03-13 """ #用上面的决策树来测试 def classify(inputTree,featLabels,testVec): # 得到树中的第一个特征 global classLabel firstStr=list(inputTree.keys())[0] # 得到第一个对应的值 secondDict=inputTree[firstStr] # 得到树中第一个特征对应的索引 # index方法查找当前列表中第一个匹配firstStr变量的元素的索引 featIndex = featLabels.index(firstStr) # 遍历树 for key in secondDict.keys(): # 如果在secondDict[key]中找到testVec[featIndex] if testVec[featIndex]==key: # 判断secondDict[key]是否为字典 if type(secondDict[key]).__name__==‘dict‘: # 若为字典,递归的寻找testVec classLabel=classify(secondDict[key],featLabels,testVec) else: # 若secondDict[key]为标签值,则将secondDict[key]赋给classLabel classLabel=secondDict[key] # 返回类标签 return classLabel def createDataSet(): # 数据集 dataSet=[[0, 0, 0, 0, ‘no‘], [0, 0, 0, 1, ‘no‘], [0, 1, 0, 1, ‘yes‘], [0, 1, 1, 0, ‘yes‘], [0, 0, 0, 0, ‘no‘], [1, 0, 0, 0, ‘no‘], [1, 0, 0, 1, ‘no‘], [1, 1, 1, 1, ‘yes‘], [1, 0, 1, 2, ‘yes‘], [1, 0, 1, 2, ‘yes‘], [2, 0, 1, 2, ‘yes‘], [2, 0, 1, 1, ‘yes‘], [2, 1, 0, 1, ‘yes‘], [2, 1, 0, 2, ‘yes‘], [2, 0, 0, 0, ‘no‘]] #分类属性 labels=[‘年龄‘,‘有工作‘,‘有自己的房子‘,‘信贷情况‘] #返回数据集和分类属性 return dataSet,labels if __name__==‘__main__‘: dataSet,labels=createDataSet() copy_Featlabels=labels[:] featLabels=[] myTree=createTree(dataSet,labels,featLabels) print(myTree) while True: try: test_Feat= input(‘请输入四个数字,每个数字为0或1,用空号隔开:‘) test_Feat= list(map(int,test_Feat.split(‘ ‘))) print(‘测试数据类别为:‘+str(classify(myTree,copy_Featlabels,test_Feat))) except: break
以上是关于决策树的主要内容,如果未能解决你的问题,请参考以下文章
sklearn决策树算法DecisionTreeClassifier(API)的使用以及决策树代码实例 - 莺尾花分类