R:从决策树中提取规则

Posted

技术标签:

【中文标题】R:从决策树中提取规则【英文标题】:R: Extracting Rules from a Decision Tree 【发布时间】:2021-12-16 15:39:18 【问题描述】:

我正在使用 R 编程语言。最近,我读到了一种新的决策树算法,称为“强化学习树”(RLT),据说它有可能将“更好的”决策树拟合到数据集。该库的文档可在此处获得:https://cran.r-project.org/web/packages/RLT/RLT.pdf

我尝试使用这个库在(著名的)鸢尾花数据集上运行分类决策树:

library(RLT)
data(iris)
fit = RLT(iris[,c(1,2,3,4)], iris$Species, model = "classification", ntrees = 1)

问题:从这里,是否有可能从这个决策树中提取“规则”?

例如,如果您使用 CART 决策树模型:

library(rpart)
library(rpart.plot)

fit <-rpart( Species ~. , data = iris)
rpart.plot(fit)

 rpart.rules(fit)

    Species  seto vers virg                                               
     setosa [1.00  .00  .00] when Petal.Length <  2.5                     
 versicolor [ .00  .91  .09] when Petal.Length >= 2.5 & Petal.Width <  1.8
  virginica [ .00  .02  .98] when Petal.Length >= 2.5 & Petal.Width >= 1.8

是否可以使用 RLT 库来做到这一点?我一直在阅读这个库的文档,似乎找不到提取决策规则的直接方法。我知道这个库通常是用来替代随机森林(没有决策规则) - 但我正在阅读这个算法的原始论文,他们指定 RLT 算法适合单个决策树(通过RLT 算法),然后像在随机森林中一样将它们聚合在一起。因此在某种程度上,RLT 算法能够拟合单个决策树——理论上应该有“决策规则”。

有谁知道如何提取这些规则?

谢谢!

参考资料:

https://www.researchgate.net/publication/277625959_Reinforcement_Learning_Trees

【问题讨论】:

自从doing it this way isn't working 以来,您是否看过tests 可能提供进入的大门,但考虑到what probability refers to which var,似乎有必要求助于理解supplementary materials - pdf。 感谢您的回复!我会研究这些链接! 测试可以揭示,但我会说我必须把补充的东西当作'他们知道他们在说什么'和我的头,但是三个类,静音 - 不是我们的东西想要拆分,强 - 找到最好的拆分位置,等等,但也许默认的调整参数(补充第 15 页)是有效的规则。 【参考方案1】:

规则以相对难以解释的表格格式存储在fit$FittedTrees[[1]]中。

我为您构建了一个相当长的函数,它将规则提取为数据框,并在需要时将树绘制为 ggplot。

RLT_tree <- function(RLT_obj, plot = TRUE)

  
  tree <- as.data.frame(t(RLT_obj$FittedTrees[[1]]))
  tree <- tree[c(2, 3, 5, 6, 8, 9, grep("Class\\d", names(tree)))]
  class_cols <- grep("Class\\d", names(tree))
  names(tree)[class_cols] <-
    RLT_obj$ylevels[1 + as.numeric(sub("Class(\\d+)", "\\1",
                                   names(tree)[class_cols]))]
  tree$variable <- RLT_obj$variablenames[tree$SplitVar1]
  tree$variable[is.na(tree$variable)] <- "(Leaf node)"
  tree$rule <- tree$variable
  tree$depth <- numeric(nrow(tree))
  tree$rightness <- numeric(nrow(tree))
  tree$group <- character(nrow(tree))
  
  walk_tree <- function(node, depth, rightness, node_label = "Start", group = "S")
  
    new_row <- tree[which(tree$Node == node),]
    new_row$depth <- depth
    new_row$rightness <- rightness
    left_label <- paste(new_row$variable, new_row$SplitValue, sep = " < ")
    right_label <- paste(new_row$variable, new_row$SplitValue, sep = " > ")
    new_row$variable <- paste(node_label, "\nn = ", new_row$NumObs)
    new_row$rule <- node_label
    if(is.nan(new_row$SplitValue)) 
      n_objs <- round(new_row[,class_cols] * new_row$NumObs)
      classify <- paste((names(tree)[class_cols])[n_objs != 0], 
                        n_objs[n_objs != 0],
                        collapse = "\n")
      new_row$variable <- paste(new_row$variable, classify, sep = "\n")
    
    new_row$group <- group
    tree[which(tree$Node == node),] <<- new_row
    if(!is.nan(new_row$SplitValue))
      walk_tree(new_row$NextLeft, depth + 1, rightness - 2/(depth/2 + 1), 
                left_label, paste(group, "L"))
      walk_tree(new_row$NextRight, depth + 1, rightness + 2/(depth/2 + 1), 
                right_label, paste(group, "R"))
    
  
  
  walk_tree(0, 0, 0)
  tree$depth <- max(tree$depth) - tree$depth
  tree$type <- is.nan(tree$NextLeft)
  tree$group <- substr(tree$group, 1, nchar(tree$group) - 1)

  if(plot)
  
  print(ggplot(tree, aes(rightness, depth)) + 
    geom_segment(aes(x = rightness, xend = rightness, 
                     y = depth, yend = depth - 1, alpha = type)) + 
    geom_line(aes(group = group)) +
    geom_label(aes(label = variable, fill = type), size = 4) + 
    theme_void() + 
    scale_x_continuous(expand = c(0, 1)) + 
    suppressWarnings(scale_alpha_discrete(range = c(1, 0)))  +
    theme(legend.position = "none"))
  
  tree$isLeaf <- is.nan(tree$NextLeft)
  tree[c(match(c("Node", "rule", "depth", "isLeaf"), names(tree)), class_cols)]

这允许:

df <- RLT_tree(fit, plot = TRUE)

df
#>    Node               rule depth isLeaf    setosa versicolor virginica
#> 1     0              Start     6  FALSE 0.3111111 0.34814815 0.3407407
#> 2     1  Sepal.Width < 3.2     5  FALSE 0.1573034 0.51685393 0.3258427
#> 3     2  Sepal.Width > 3.2     5  FALSE 0.6086957 0.02173913 0.3695652
#> 4     3 Sepal.Length < 5.4     4  FALSE 0.7000000 0.30000000 0.0000000
#> 5     4 Sepal.Length > 5.4     4   TRUE 0.0000000 0.57971014 0.4202899
#> 6     5 Petal.Length < 1.3     3   TRUE 1.0000000 0.00000000 0.0000000
#> 7     6 Petal.Length > 1.3     3  FALSE 0.6000000 0.40000000 0.0000000
#> 8     7 Petal.Length < 1.4     2   TRUE 1.0000000 0.00000000 0.0000000
#> 9     8 Petal.Length > 1.4     2  FALSE 0.5000000 0.50000000 0.0000000
#> 10    9 Petal.Length < 3.9     1  FALSE 0.7500000 0.25000000 0.0000000
#> 11   10 Petal.Length > 3.9     1   TRUE 0.0000000 1.00000000 0.0000000
#> 12   11 Sepal.Length < 4.9     0   TRUE 1.0000000 0.00000000 0.0000000
#> 13   12 Sepal.Length > 4.9     0   TRUE 0.0000000 1.00000000 0.0000000
#> 14   13  Petal.Width < 0.2     4   TRUE 1.0000000 0.00000000 0.0000000
#> 15   14  Petal.Width > 0.2     4  FALSE 0.3793103 0.03448276 0.5862069
#> 16   15 Sepal.Length < 5.7     3   TRUE 1.0000000 0.00000000 0.0000000
#> 17   16 Sepal.Length > 5.7     3  FALSE 0.0000000 0.05555556 0.9444444
#> 18   17  Sepal.Width < 3.3     2   TRUE 0.0000000 0.00000000 1.0000000
#> 19   18  Sepal.Width > 3.3     2  FALSE 0.0000000 0.08333333 0.9166667
#> 20   19 Petal.Length < 6.1     1  FALSE 0.0000000 0.11111111 0.8888889
#> 21   20 Petal.Length > 6.1     1   TRUE 0.0000000 0.00000000 1.0000000
#> 22   21 Sepal.Length < 6.3     0   TRUE 0.0000000 0.16666667 0.8333333
#> 23   22 Sepal.Length > 6.3     0   TRUE 0.0000000 0.00000000 1.0000000

为了在更一般的情况下展示这个作品,我们还可以这样做:

fit2 = RLT(mtcars[,1:3], factor(rownames(mtcars)), model = "classification", ntrees = 1)

df <- RLT_tree(fit2)

【讨论】:

@Allan Cameron:非常感谢您的回答!看起来棒极了!我会尝试自己研究“fit$FittedTrees[[1]]”的输出! (我很想知道您是如何弄清楚如何解释表格的)。您认为您提供的答案也适用于回归示例吗?再次感谢! @stats555 它有点适用于回归模型。不过,可能需要对其进行调整以在回归结果中包含每个节点的 NodeMean 值,并将其显示在图上。如果您遇到困难,我很乐意在某个时候尝试一下。 @AllanCameron 你真是疯了?哇

以上是关于R:从决策树中提取规则的主要内容,如果未能解决你的问题,请参考以下文章

如何从 scikit-learn 决策树中提取决策规则?

如何从 scikit-learn 决策树中提取决策规则?

如何从 scikit-learn 决策树中提取决策规则?

提取规则以预测决策树中的子节点或概率分数

CART 决策树中的冲突拆分

从 rpart 包中的决策规则中提取信息