sklearn 决策树:获取每个节点和叶子的记录(**有效**)

Posted

技术标签:

【中文标题】sklearn 决策树:获取每个节点和叶子的记录(**有效**)【英文标题】:sklearn decision tree: get records at each node and leaf (**efficently**) 【发布时间】:2021-12-25 14:24:37 【问题描述】:

我正在一些 pandas 数据框 X 上训练决策树分类器。

clf = DecisionTreeClassifier()
clf = clf.fit(X, y)

现在我遍历树 clf.tree_ 并想要获取属于该内部节点或叶的记录(最好作为数据框)。我目前所做的如下所示。

fn = [ X.columns[i] if i != TREE_UNDEFINED else "undefined!"  for i in clf.tree_.feature ]

def recurse(node, tmp):
    tree = clf.tree_
    if self.test_node(tmp):
        return
    
    if tree.feature[node] != TREE_UNDEFINED:
        mask = tmp[fn[node]] <= tree.threshold[node]
        recurse(tree.children_left[node], tmp[mask])
        recurse(tree.children_right[node], tmp[~mask])
    
recurse(0, X)

这显然是可行的,但是在为 10K 棵树执行此操作时,我发现使用分析器时,我的代码中有 95+% 用于拆分数据帧。对数据的拟合可能是 2%,剩下的就是我对每个节点的数据框所做的。

有没有更有效的方法来拆分数据?

我假设 DT 在内部必须拆分数据(我可以获得每个节点的记录数)。我可以以某种方式让它附加将 df 放在节点上吗?

** 更新 **

建议使用clf.decision_path(X).toarray()。在这个矩阵中,每一列j 代表一个节点,i 行中的1 表示它通过了该节点。

我尝试了几种“方法”来使用此矩阵获取每个节点的 df。所有这些都比我目前使用的幼稚方法慢。

Walk tree: default: 2.4888 s +- 0.01 s per loop (mean +- std. dev. of 10 runs, 50 loops each)
Walk tree: no recursion: 2.5427 s +- 0.07 s per loop (mean +- std. dev. of 10 runs, 50 loops each)
Walk tree: decision path Numpy : 16.5346 s +- 0.08 s per loop (mean +- std. dev. of 10 runs, 50 loops each)
Walk tree: decision path Scipy: 8.8154 s +- 0.56 s per loop (mean +- std. dev. of 10 runs, 50 loops each)
Walk tree: decision path Pandas: 28.3901 s +- 0.69 s per loop (mean +- std. dev. of 10 runs, 50 loops each)

对于使用此数组的最快方法 Scipy,我还尝试查看获取索引或部分 df 是否需要最多时间。

Walk tree: decision path Scipy: 5.3404 s +- 0.20 s per loop (mean +- std. dev. of 10 runs, 30 loops each)
Walk tree: decision path Scipy (take=False): 4.5698 s +- 0.27 s per loop (mean +- std. dev. of 10 runs, 30 loops each)

我还尝试将上面的基本递归更改为使用df.query(..),但这也较慢。

【问题讨论】:

【参考方案1】:

我相信

pd.DataFrame(clf.decision_path(X).toarray())

可能是你想要的。如果观察 i 通过树的节点 j,则结果中的条目 [i, j] 将为 1。 还有一个关于决策树结构的非常好的示例here 可能会有所帮助。

【讨论】:

不幸的是,像这样在节点上获取记录(行)比我当前的方法(相当)慢。根据我如何获得数组中每列j 的非零元素,它可能在 1.5 倍(Scipy 和 CSC 矩阵)到 3 倍(熊猫 df)之间。 只是为了澄清:我的回答中调用创建的输出是您想要的,但它只是需要很长时间?或者您是否需要对clf.decision_path(X) 返回的节点指示符矩阵进行额外处理才能获得您想要的最终结果?如果是这样,您能否发布一个包含所需输出的玩具数据集的简短示例? 我需要额外的处理。我需要每个节点的观察/记录,所以我需要通过j 列中的索引“切片”原始 df。使用矩阵对所有节点执行此操作(令人惊讶的是)比简单方法慢得多。我最终改变了我的算法,以便我可以根据决策树已经拥有的信息跳过一些节点并获得 2 倍的改进(还不够,但还可以)

以上是关于sklearn 决策树:获取每个节点和叶子的记录(**有效**)的主要内容,如果未能解决你的问题,请参考以下文章

我如何从决策树中预测 x_train 的位置获取叶子的节点号?

使用 sklearn,我如何找到决策树的深度?

决策树算法

如何找到决策树中每个叶子或节点的索引?

机器学习系列-决策树

如何从每个节点提取sklearn决策树规则到pandas布尔条件?