是否可以在 scikit-learn 中打印决策树?
Posted
技术标签:
【中文标题】是否可以在 scikit-learn 中打印决策树?【英文标题】:Is it possible to print the decision tree in scikit-learn? 【发布时间】:2014-10-06 03:14:00 【问题描述】:有没有办法在 scikit-learn 中打印经过训练的决策树?我想为我的论文训练一个决策树,我想把树的图片放在论文中。这可能吗?
【问题讨论】:
【参考方案1】:有一种方法可以导出为graph_viz格式:http://scikit-learn.org/stable/modules/generated/sklearn.tree.export_graphviz.html
所以来自在线文档:
>>> from sklearn.datasets import load_iris
>>> from sklearn import tree
>>>
>>> clf = tree.DecisionTreeClassifier()
>>> iris = load_iris()
>>>
>>> clf = clf.fit(iris.data, iris.target)
>>> tree.export_graphviz(clf,
... out_file='tree.dot')
然后您可以使用图形 viz 加载它,或者如果您安装了 pydot,那么您可以更直接地执行此操作:http://scikit-learn.org/stable/modules/tree.html
>>> from sklearn.externals.six import StringIO
>>> import pydot
>>> dot_data = StringIO()
>>> tree.export_graphviz(clf, out_file=dot_data)
>>> graph = pydot.graph_from_dot_data(dot_data.getvalue())
>>> graph.write_pdf("iris.pdf")
将生成一个 svg,无法在此处显示,因此您必须点击链接:http://scikit-learn.org/stable/_images/iris.svg
更新
自从我第一次回答这个问题以来,行为似乎发生了变化,现在它返回一个list
,因此你会收到这个错误:
AttributeError: 'list' object has no attribute 'write_pdf'
首先,当您看到这一点时,只需打印对象并检查该对象就值得了,您想要的很可能是第一个对象:
graph[0].write_pdf("iris.pdf")
感谢@NickBraunagel 的评论
【讨论】:
我收到此错误。AttributeError: 'list' object has no attribute 'write_pdf'
我该如何解决这个问题?
@EdChum 你能检查一下这个***.com/questions/48880557/…
@ErnestSoo(以及遇到您的错误的任何其他人:pydot.graph_from_dot_data()
返回所需的 graph
(pydot.Dot
对象)但它在 list
中返回它:所以,访问列表的第一个访问pydot.Dot
对象的对象:graph[0].write_pdf("iris.pdf")
@NickBraunagel 因为似乎很多人都收到了这个错误,所以我将把它作为更新添加,看起来这是我在 3 年前回答这个问题以来的一些行为变化,谢谢
除了测试数据,你会如何做同样的事情?【参考方案2】:
虽然我迟到了,但以下综合说明可能对其他想要显示决策树输出的人有用:
安装必要的模块:
-
安装
graphviz
。我使用了 conda 的安装包here
(建议在 pip install graphviz
上安装 pip
不安装
包括实际的 GraphViz executables)
通过 pip (pip install pydot
) 安装 pydot
将包含 .exe 文件(例如 dot.exe)的 graphviz 文件夹目录添加到您的环境变量 PATH 中
运行上面的 EdChum(注意:graph
是一个包含 pydot.Dot
对象的 list
):
from sklearn.datasets import load_iris
from sklearn import tree
from sklearn.externals.six import StringIO
import pydot
clf = tree.DecisionTreeClassifier()
iris = load_iris()
clf = clf.fit(iris.data, iris.target)
dot_data = StringIO()
tree.export_graphviz(clf, out_file=dot_data)
graph = pydot.graph_from_dot_data(dot_data.getvalue())
graph[0].write_pdf("iris.pdf") # must access graph's first element
现在您将在环境的默认目录中找到“iris.pdf”
【讨论】:
【参考方案3】:我知道有 4 种绘制 scikit-learn 决策树的方法:
使用sklearn.tree.export_text
方法打印树的文本表示
使用sklearn.tree.plot_tree
方法绘图(需要matplotlib
)
使用sklearn.tree.export_graphviz
方法绘图(需要graphviz
)
使用dtreeviz
包进行绘图(需要dtreeviz
和graphviz
)
最简单的就是导出为文本表示。示例决策树如下所示:
|--- feature_2 <= 2.45
| |--- class: 0
|--- feature_2 > 2.45
| |--- feature_3 <= 1.75
| | |--- feature_2 <= 4.95
| | | |--- feature_3 <= 1.65
| | | | |--- class: 1
| | | |--- feature_3 > 1.65
| | | | |--- class: 2
| | |--- feature_2 > 4.95
| | | |--- feature_3 <= 1.55
| | | | |--- class: 2
| | | |--- feature_3 > 1.55
| | | | |--- feature_0 <= 6.95
| | | | | |--- class: 1
| | | | |--- feature_0 > 6.95
| | | | | |--- class: 2
| |--- feature_3 > 1.75
| | |--- feature_2 <= 4.85
| | | |--- feature_1 <= 3.10
| | | | |--- class: 2
| | | |--- feature_1 > 3.10
| | | | |--- class: 1
| | |--- feature_2 > 4.85
| | | |--- class: 2
如果你安装了matplotlib
,你可以用sklearn.tree.plot_tree
绘图:
tree.plot_tree(clf) # the clf is your decision tree model
示例输出类似于使用export_graphviz
得到的输出:
你也可以试试dtreeviz
包。它会给你更多的信息。例子:
您可以在这篇博文中找到 sklearn 决策树的不同可视化与代码 sn-ps 的比较:link。
【讨论】:
以上是关于是否可以在 scikit-learn 中打印决策树?的主要内容,如果未能解决你的问题,请参考以下文章