网上的版本好像好久都没更新了treePlotter是没有人用了么。今天学习的时候发现有些地方已经改了,我改的是在python 3.6 上的运行版本,需要导入matplotlib.pyplot
import matplotlib.pyplot as plt # 定义决策树决策结果属性 descisionNode = dict(boxstyle=‘sawtooth‘, fc=‘0.8‘) leafNode = dict(boxstyle=‘round4‘, fc=‘0.8‘) arrow_args = dict(arrowstyle=‘<-‘)
# myTree = {‘no surfacing‘: {0: ‘no‘, 1: {‘flippers‘: {0: ‘no‘, 1: ‘yes‘}}}} def plotNode(nodeTxt, centerPt, parentPt, nodeType): # nodeTxt为要显示的文本,centerNode为文本中心点, nodeType为箭头所在的点, parentPt为指向文本的点 createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords=‘axes fraction‘, xytext=centerPt, textcoords=‘axes fraction‘, va=‘center‘, ha=‘center‘, bbox=nodeType, arrowprops=arrow_args) # def createPlot(): # fig = plt.figure(1, facecolor=‘white‘) # fig.clf() # # createPlot.ax1为全局变量,绘制图像句柄 # # frameon表示是否绘制坐标轴矩形 # createPlot.ax1 = plt.subplot(111, frameon=False) # plotNode(‘a decision node‘, (0.5, 0.1), (0.1, 0.5), descisionNode) # plotNode(‘a leaf node‘, (0.8, 0.1), (0.3, 0.8), leafNode) # plt.show() # 这个是用来测试的 # -----------分割线------------- # 获取树的叶子数和树的深度 def getNumLeafs(myTree): numLeafs = 0 firstStr = list(myTree.keys())[0] secondDict = myTree[firstStr] for key in secondDict.keys(): if type(secondDict[key]).__name__ == ‘dict‘: numLeafs += getNumLeafs(secondDict[key]) else: numLeafs += 1 return numLeafs def getTreeDepth(myTree): maxDepth = 0 firstStr = list(myTree.keys())[0] # 这个是改的地方,原来myTree.keys()返回的是dict_keys类,不是列表,运行会报错。有好几个地方这样 secondDict = myTree[firstStr] for key in secondDict.keys(): if type(secondDict[key]).__name__ == ‘dict‘: thisDepth = 1 + getTreeDepth(secondDict[key]) else: thisDepth = 1 if thisDepth > maxDepth: maxDepth = thisDepth return maxDepth # ---------分割线------------- # 制图 def createPlot(inTree): fig = plt.figure(1, facecolor=‘white‘) fig.clf() axprops = {‘xticks‘: None, ‘yticks‘: None} createPlot.ax1 = plt.subplot(111, frameon=False) plotTree.totalW = float(getNumLeafs(inTree)) # 全局变量宽度 = 叶子数目 plotTree.totalD = float(getTreeDepth(inTree)) # 全局变量高度 = 深度 plotTree.xOff = -0.5/plotTree.totalW plotTree.yOff = 1.0 plotTree(inTree, (0.5, 1.0), ‘‘) plt.show() def plotTree(myTree, parentPt, nodeTxt): numLeafs = getNumLeafs(myTree) depth = getTreeDepth(myTree) firstStr = list(myTree.keys())[0] # cntrPt文本中心点, parentPt指向文本中心的点 cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff) plotMidText(cntrPt, parentPt, nodeTxt) plotNode(firstStr, cntrPt, parentPt, descisionNode) seconDict = myTree[firstStr] plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD for key in seconDict.keys(): if type(seconDict[key]).__name__ == ‘dict‘: plotTree(seconDict[key], cntrPt, str(key)) else: plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW plotNode(seconDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode) plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key)) plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD def plotMidText(cntrPt, parentPt, txtString): xMid = (parentPt[0] - cntrPt[0])/2.0 + cntrPt[0] yMid = (parentPt[1] - cntrPt[1])/2.0 + cntrPt[1] createPlot.ax1.text(xMid, yMid, txtString, va=‘center‘, ha=‘center‘, rotation=30) # createPlot(myTree)
这个treePlotter导入了就可以把原来得到的决策树模型导入啦,而且要注意是以字典形式导入,所以保存和导入文件的时候最好用json。
发布5分钟之后,突然发现已经有人改过了,那就只算是个学习笔记吧 - -