scikit学习决策树导出graphviz - 决策树中的错误类名

Posted

技术标签:

【中文标题】scikit学习决策树导出graphviz - 决策树中的错误类名【英文标题】:scikit learn decision tree export graphviz - wrong class names in the decision tree 【发布时间】:2017-05-03 15:20:15 【问题描述】:

我在“scikit learn/decision tree/export graphviz”的决策树中得到错误的类名。程序如下:

import matplotlib.pyplot as plt
import matplotlib.image as img
import pydot
from sklearn import tree

digital_table = [[0, 0], [0, 1], [1, 0], [1, 1]]
digital_label = ['zero', 'one', 'two', 'three']
digital_name = ['idx-1', 'idx-2']

digital_tree = tree.DecisionTreeClassifier()
digital_tree.fit(digital_table, digital_label)

with open("digital.dot", 'w') as f:
    f = tree.export_graphviz(digital_tree, 
                            feature_names=digital_name,
                            class_names=digital_label,
                            filled=True, rounded=True,
                            out_file=f)
(graph,) = pydot.graph_from_dot_file("digital.dot")
graph.write_png("digital.png")

plt.imshow(img.imread('digital.png'))
plt.show()

输出如下:

问题在于叶子中显示的类名。例如,如果 idx-1 为 1 且 idx-2 为 1,则绿色框应标记为“三”。但是,图像显示标签为“一”。谁能给你的cmets?

【问题讨论】:

万一还有人有疑问,问题是模型是针对字符串标签['zero', 'one', 'two', 'three']进行训练的。该函数不知道哪一个调用零,哪一个调用一。所以它最终以字母顺序使用它们one 变为 0,three 变为 2,依此类推。处理这个问题的最好方法是将标签转换为整数类[0, 1, 2, 3] 【参考方案1】:

当您使用DecisionTreeClassifier时,您应该将类​​标签更改为0,1,2之类的数字

然后使用:

classe_names = decision_tree_classifier.classes_

它将按升序为您提供类的标签。然后以相同的顺序指定您的 class_label。可以是字符串。

【讨论】:

【参考方案2】:

在将类标签传递给export_graphviz之前,尝试按字母顺序对它们进行排序

【讨论】:

感谢 cmets。但是,我认为表格元素的顺序和标签元素的顺序应该是同步的。对吗?

以上是关于scikit学习决策树导出graphviz - 决策树中的错误类名的主要内容,如果未能解决你的问题,请参考以下文章

更改使用导出 graphviz 创建的决策树图的颜色

从 scikit-learn 中的文件加载决策树

使用 Graphviz 显示此决策树

在给出错误的决策树期间导出 graphviz

Scikit决策树分类特征

如何将python生成的决策树利用graphviz画出来