scikit-learn 将每个叶节点的决策标签保存在其树结构中的啥位置?
Posted
技术标签:
【中文标题】scikit-learn 将每个叶节点的决策标签保存在其树结构中的啥位置?【英文标题】:Where does scikit-learn hold the decision labels of each leaf node in its tree structure?scikit-learn 将每个叶节点的决策标签保存在其树结构中的什么位置? 【发布时间】:2017-10-24 20:01:12 【问题描述】:我已经使用 scikit-learn 训练了一个随机森林模型,现在我想将它的树结构保存在一个文本文件中,以便我可以在其他地方使用它。 根据this link,一个树对象由许多并行数组组成,每个数组都包含有关树的不同节点的一些信息(例如左孩子、右孩子、它检查的特征……)。但是似乎没有关于每个叶节点对应的类标签的信息!上面链接中提供的示例甚至都没有提到它。
有谁知道 scikit-learn 决策树结构中存储的类标签在哪里?
【问题讨论】:
【参考方案1】:查看sklearn.tree.DecisionTreeClassifier.tree_.value
的文档:
from sklearn.datasets import load_iris
from sklearn.cross_validation import cross_val_score
from sklearn.tree import DecisionTreeClassifier
clf = DecisionTreeClassifier(random_state=0)
iris = load_iris()
clf.fit(iris.data, iris.target)
print(clf.classes_)
[0, 1, 2]
print(clf.tree_.value)
[[[ 50. 50. 50.]]
[[ 50. 0. 0.]]
[[ 0. 50. 50.]]
[[ 0. 49. 5.]]
[[ 0. 47. 1.]]
[[ 0. 47. 0.]]
[[ 0. 0. 1.]]
[[ 0. 2. 4.]]
[[ 0. 0. 3.]]
[[ 0. 2. 1.]]
[[ 0. 2. 0.]]
[[ 0. 0. 1.]]
[[ 0. 1. 45.]]
[[ 0. 1. 2.]]
[[ 0. 0. 2.]]
[[ 0. 1. 0.]]
[[ 0. 0. 43.]]]
clf.tree_.value
中的每一行“包含每个节点的常量预测值”(help(clf.tree_)
),它对应于索引到索引到clf.classes_
。
请参阅this answer 了解(几乎没有)更多详细信息。
【讨论】:
加上答案,对于这个数组中的每一行,你可以通过clf.classes_[np.argmax(value)]
得到预测的类标签。
@not_a_robot 谢谢。你解释得很完美。但是我仍然找不到文档中提到 clf.tree_.value 的位置。我想我不再需要它了,因为你的答案正是我想要的。
又一个小问题。看起来 clf.classes_ 给了我 [0,...,n-1] 的标签,不管我使用什么标签。我对吗?我期待 [1,...,n] 在我的情况下。
我相信标签是零索引的,这就是为什么它是 [0, n-1]。以上是关于scikit-learn 将每个叶节点的决策标签保存在其树结构中的啥位置?的主要内容,如果未能解决你的问题,请参考以下文章
scikit-learn决策树回归:检索叶子的所有样本(不是平均值)