如何在 scikit-learn 的随机森林的 graphviz-graph 中找到一个类?
Posted
技术标签:
【中文标题】如何在 scikit-learn 的随机森林的 graphviz-graph 中找到一个类?【英文标题】:How to find a Class in the graphviz-graph of the Random Forest of scikit-learn? 【发布时间】:2019-09-02 02:48:19 【问题描述】:我用 10 个估计器训练了一个随机森林分类器。然后,我将所有带有 graphviz 的树形图保存为 dot 和 png 文件。 最后,我做 RandomForest.predict。
从预测的输出中,我选择了一个预测的类,并通过搜索点文件在图中搜索它们(只需使用 STRG+F 搜索纯文本,与另一个模型一起使用)。 但我找不到那个班级。 当我查看 png 文件时,我只在节点中看到一个类。 (我不能在这里显示图表)。 这很奇怪,因为如果没有任何具有不同类的节点,它就不会预测它们。
我的目标是跟踪数据对象如何预测其类别的路径。
以下是我的代码的相关部分:
rfclf = RandomForestClassifier(class_weight = 'balanced')
rfclf.fit(x,y)
输出:
RandomForestClassifier(bootstrap=True, class_weight='平衡', 标准='gini',max_depth=None,max_features='auto', max_leaf_nodes=无,min_impurity_decrease=0.0, min_impurity_split=无,min_samples_leaf=1, min_samples_split=2, min_weight_fraction_leaf=0.0, n_estimators=10,n_jobs=None,oob_score=False, random_state=None,verbose=0,warm_start=False)
estimator=rfclf.estimators_[8] #or [0],[1],[2],.....[9] because there are 10 estimators
# Export as dot-file
export_graphviz(estimator, out_file='Graphs/rfclf8.dot',
feature_names = x.columns,
class_names = y,
rounded = True, proportion = False,
precision = 2, filled = True)
# convert to PNG with system command (needs Graphviz)
from subprocess import call
call(['dot', '-Tpng', 'Graphs/rfclf8.dot', '-o', 'Graphs/rfclf8.png', '-Gdpi=600'])
#predict
rfclf.predict(dfP)
输出:array(['-不同的类-, dtype=object)
代码有问题吗?它适用于不同的数据集。
【问题讨论】:
我们不可能在没有看到图表或至少其中一部分(叶子)的情况下回答。将它们作为 PNG 文件上传 【参考方案1】:为了跟踪对特定样本进行分类所采用的路径,您应该使用 decision_path() 的 RandomForestClassifier。它从 scikit-learn 0.18.0 开始可用
示例代码可在 https://scikit-learn.org/stable/auto_examples/tree/plot_unveil_tree_structure.html
【讨论】:
以上是关于如何在 scikit-learn 的随机森林的 graphviz-graph 中找到一个类?的主要内容,如果未能解决你的问题,请参考以下文章
如何在 Python scikit-learn 中输出随机森林中每棵树的回归预测?
如何在 scikit-learn 中执行随机森林模型的交叉验证?
如何在 scikit-learn 的随机森林的 graphviz-graph 中找到一个类?