sklearn决策树的BFS遍历

Posted

技术标签:

【中文标题】sklearn决策树的BFS遍历【英文标题】:BFS traversal of sklearn decision tree 【发布时间】:2020-08-03 00:12:29 【问题描述】:

如何进行sklearn决策树的广度优先搜索遍历?

在我的代码中,我尝试了 sklearn.tree_ 库并使用了各种函数,例如 tree_.feature 和 tree_.threshold 来理解树的结构。但是这些函数是对树进行dfs遍历的,如果我想做bfs我应该怎么做呢?

假设

clf1 = DecisionTreeClassifier( max_depth = 2 )
clf1 = clf1.fit(x_train, y_train)

这是我的分类器,生成的决策树是

然后我使用以下函数遍历了树

def encoding(clf, features):
l1 = list()
l2 = list()

for i in range(len(clf.tree_.feature)):
    if(clf.tree_.feature[i]>=0):
        l1.append( features[clf.tree_.feature[i]])
        l2.append(clf.tree_.threshold[i])
    else:
        l1.append(None)
        print(np.max(clf.tree_.value))
        l2.append(np.argmax(clf.tree_.value[i]))

l = [l1 , l2]

return np.array(l)

产生的输出是

array([['地址', '年龄', 无, 无, '年龄', 无, 无], [0.5, 17.5, 2, 1, 15.5, 1, 1]], dtype=object) 其中第一个数组是节点的特征,或者如果它是叶节点,那么它被标记为无,第二个数组是特征节点的阈值,对于类节点它是类,但这是树的 dfs 遍历我想做 bfs 遍历我应该怎么做? 以上部分已回答。

我想知道我们能否将树存储到数组中,使其看起来像是一棵完整的二叉树,以便第 i 个节点的子节点存储在第 2i + 1 和 2i +2 索引处?

对于上面生成的树输出是 array([['address', 'age', None, None], [0.5, 15.5, 1, 1]], dtype=object)

但是想要的输出是

array([['address', None, 'age', None, None, None, None], [0.5, -1, 15.5, -1, -1, 1, 1]], dtype=object)

如果值在第一个数组中为无,在第二个数组中为 -1,则表示该节点不存在。所以这里年龄是地址的右孩子在 2 * 0 + 2 = 2 数组中的索引以及类似的年龄左右孩子分别在数组的 2 * 2 + 1 = 5th 索引和 2 * 2 + 2 = 6th 索引处找到。

【问题讨论】:

这能回答你的问题吗? Traversal of sklearn decision tree 是的,您已经解决了问题的遍历部分,现在我想以这样一种方式存储树,即第 i 个节点的子节点存储在数组的第 2i 个和第 2i +1 个位置。 @Dion 你能帮我解决这个问题吗? 请用预期的(示例)输入和输出更新问题。 @Dion 我添加了一个示例,如果您需要进一步说明,请告诉我 【参考方案1】:

这样的?

def reformat_tree(clf):
    tree = clf.tree_

    feature_out = np.full((2 ** tree.max_depth), -1, dtype=tree.feature.dtype)
    threshold_out = np.zeros((2 ** tree.max_depth), dtype=tree.threshold.dtype)

    stack = []
    stack.append((0, 0))

    while stack:
        current_node, new_node = stack.pop()

        feature_out[new_node] = tree.feature[current_node]
        threshold_out[new_node] = tree.threshold[current_node]

        left_child = tree.children_left[current_node]
        if left_child >= 0:
            stack.append((left_child, 2 * current_node + 1))

        right_child = tree.children_right[current_node]
        if right_child >= 0:
            stack.append((right_child, 2 * current_node + 2))

    return feature_out, threshold_out

我无法在你的树上测试它,因为你还没有给出重现它的方法,但它应该可以工作。

该函数以所需格式返回特征和阈值。特征值为-1表示节点不存在,-2表示节点为叶子。

这是通过遍历树并跟踪当前位置来实现的。

【讨论】:

以上是关于sklearn决策树的BFS遍历的主要内容,如果未能解决你的问题,请参考以下文章

机器学习-------sklearn决策树分析

决策树 Sklearn - 树的深度和准确性

详解决策树-决策树的优缺点 & 分类树在合成数集上的表现菜菜的sklearn课堂笔记

使用 sklearn,我如何找到决策树的深度?

python-sklearn数据拆分与决策树的实现

sklearn实现决策树算法