scikit-learn 将每个叶节点的决策标签保存在其树结构中的啥位置?

Posted

技术标签:

【中文标题】scikit-learn 将每个叶节点的决策标签保存在其树结构中的啥位置?【英文标题】:Where does scikit-learn hold the decision labels of each leaf node in its tree structure?scikit-learn 将每个叶节点的决策标签保存在其树结构中的什么位置? 【发布时间】:2017-10-24 20:01:12 【问题描述】:

我已经使用 scikit-learn 训练了一个随机森林模型,现在我想将它的树结构保存在一个文本文件中,以便我可以在其他地方使用它。 根据this link,一个树对象由许多并行数组组成,每个数组都包含有关树的不同节点的一些信息(例如左孩子、右孩子、它检查的特征……)。但是似乎没有关于每个叶节点对应的类标签的信息!上面链接中提供的示例甚至都没有提到它。

有谁知道 scikit-learn 决策树结构中存储的类标签在哪里?

【问题讨论】:

【参考方案1】:

查看sklearn.tree.DecisionTreeClassifier.tree_.value 的文档:

from sklearn.datasets import load_iris
from sklearn.cross_validation import cross_val_score
from sklearn.tree import DecisionTreeClassifier

clf = DecisionTreeClassifier(random_state=0)
iris = load_iris()

clf.fit(iris.data, iris.target)

print(clf.classes_)

[0, 1, 2]

print(clf.tree_.value)

[[[ 50.  50.  50.]]

 [[ 50.   0.   0.]]

 [[  0.  50.  50.]]

 [[  0.  49.   5.]]

 [[  0.  47.   1.]]

 [[  0.  47.   0.]]

 [[  0.   0.   1.]]

 [[  0.   2.   4.]]

 [[  0.   0.   3.]]

 [[  0.   2.   1.]]

 [[  0.   2.   0.]]

 [[  0.   0.   1.]]

 [[  0.   1.  45.]]

 [[  0.   1.   2.]]

 [[  0.   0.   2.]]

 [[  0.   1.   0.]]

 [[  0.   0.  43.]]]

clf.tree_.value 中的每一行“包含每个节点的常量预测值”(help(clf.tree_)),它对应于索引到索引到clf.classes_

请参阅this answer 了解(几乎没有)更多详细信息。

【讨论】:

加上答案,对于这个数组中的每一行,你可以通过clf.classes_[np.argmax(value)]得到预测的类标签。 @not_a_robot 谢谢。你解释得很完美。但是我仍然找不到文档中提到 clf.tree_.value 的位置。我想我不再需要它了,因为你的答案正是我想要的。 又一个小问题。看起来 clf.classes_ 给了我 [0,...,n-1] 的标签,不管我使用什么标签。我对吗?我期待 [1,...,n] 在我的情况下。 我相信标签是零索引的,这就是为什么它是 [0, n-1]。

以上是关于scikit-learn 将每个叶节点的决策标签保存在其树结构中的啥位置?的主要内容,如果未能解决你的问题,请参考以下文章

如何故意在 scikit-learn 中过度拟合决策树?

scikit-learn决策树回归:检索叶子的所有样本(不是平均值)

决策树算法

获取数据框中列的唯一值的计数,这些值最终出现在决策树的每个叶节点中?

《统计学习方法》读书笔记之决策树

Python数据挖掘决策树