在 rpart 的节点中获取观察结果(即:CART)

Posted

技术标签:

【中文标题】在 rpart 的节点中获取观察结果(即:CART)【英文标题】:Getting the observations in a rpart's node (i.e.: CART) 【发布时间】:2016-08-13 09:51:56 【问题描述】:

我想检查到达 rpart 决策树中某个节点的所有观察结果。例如,在以下代码中:

fit <- rpart(Kyphosis ~ Age + Start, data = kyphosis)
fit

n= 81 

node), split, n, loss, yval, (yprob)
      * denotes terminal node

 1) root 81 17 absent (0.79012346 0.20987654)  
   2) Start>=8.5 62  6 absent (0.90322581 0.09677419)  
     4) Start>=14.5 29  0 absent (1.00000000 0.00000000) *
     5) Start< 14.5 33  6 absent (0.81818182 0.18181818)  
      10) Age< 55 12  0 absent (1.00000000 0.00000000) *
      11) Age>=55 21  6 absent (0.71428571 0.28571429)  
        22) Age>=111 14  2 absent (0.85714286 0.14285714) *
        23) Age< 111 7  3 present (0.42857143 0.57142857) *
   3) Start< 8.5 19  8 present (0.42105263 0.57894737) *

我想查看节点 (5) 中的所有观察结果(即:Start>=8.5 & Start

有什么建议可以解决这个问题吗?

【问题讨论】:

【参考方案1】:

rpart 返回 rpart.object 元素,其中包含您需要的信息:

require(rpart)
fit2 <- rpart(Kyphosis ~ Age + Start, data = kyphosis)
fit2

get_node_date <-function(nodeId,fit)
  
  fit$frame[toString(nodeId),"n"]



for (i in c(1,2,4,5,10,11,22,23,3) )
  cat(get_node_date(i,fit2),"\n")

【讨论】:

您无法通过此获得观察结果,而只能获得属于某个类别的观察次数【参考方案2】:

似乎没有这样的功能可以从特定节点提取观察结果。我将按如下方式解决它:首先确定您感兴趣的节点使用了哪些规则。您可以使用path.rpart。然后你可以一个接一个地应用规则来提取观察结果。

这种方法作为一个函数:

get_node_date <- function(tree = fit, node = 5)
  rule <- path.rpart(tree, node)
  rule_2 <- sapply(rule[[1]][-1], function(x) strsplit(x, '(?<=[><=])(?=[^><=])|(?<=[^><=])(?=[><=])', perl = TRUE))
  ind <- apply(do.call(cbind, lapply(rule_2, function(x) eval(call(x[2], kyphosis[,x[1]], as.numeric(x[3]))))), 1, all)
  kyphosis[ind,]
  

对于节点 5,您会得到:

get_node_date()

 node number: 5 
   root
   Start>=8.5
   Start< 14.5
   Kyphosis Age Number Start
2    absent 158      3    14
10  present  59      6    12
11  present  82      5    14
14   absent   1      4    12
18   absent 175      5    13
20   absent  27      4     9
23  present  96      3    12
26   absent   9      5    13
28   absent 100      3    14
32   absent 125      2    11
33   absent 130      5    13
35   absent 140      5    11
37   absent   1      3     9
39   absent  20      6     9
40  present  91      5    12
42   absent  35      3    13
46  present 139      3    10
48   absent 131      5    13
50   absent 177      2    14
51   absent  68      5    10
57   absent   2      3    13
59   absent  51      7     9
60   absent 102      3    13
66   absent  17      4    10
68   absent 159      4    13
69   absent  18      4    11
71   absent 158      5    14
72   absent 127      4    12
74   absent 206      4    10
77  present 157      3    13
78   absent  26      7    13
79   absent 120      2    13
81   absent  36      4    13

【讨论】:

【参考方案3】:

partykit 包也为此提供了一个固定的解决方案。您只需要将rpart 对象转换为party 类,以便使用其统一的接口来处理树。然后就可以使用data_party()函数了。

使用问题中的fit 并加载library("partykit"),您可以首先将rpart 树强制转换为party

pfit <- as.party(fit)
plot(pfit)

以您想要的方式提取数据只有两个小麻烦:(1) 原始拟合中的model.frame() 总是在强制中丢弃,需要手动重新附加。 (2) 对节点使用不同的编号方案。您现在需要节点 4(而不是 5)。

pfit$data <- model.frame(fit)
data4 <- data_party(pfit, 4)
dim(data4)
## [1] 33  5
head(data4)
##    Kyphosis Age Start (fitted) (response)
## 2    absent 158    14        7     absent
## 10  present  59    12        8    present
## 11  present  82    14        8    present
## 14   absent   1    12        5     absent
## 18   absent 175    13        7     absent
## 20   absent  27     9        5     absent

另一种方法是从节点 4 开始对子树进行子集化,然后从中获取数据:

pfit4 <- pfit[4]
plot(pfit4)

那么data_party(pfit4) 给你的和上面的data4 一样。而pfit4$data 为您提供没有(fitted) 节点和预测的(response) 的数据。

【讨论】:

如果您使用ptree$data &lt;- model.frame(eval(tree$call$data)),公式中未使用的变量将不会被删除 True...但仅当data 包含formula 中的所有变量时,情况不一定如此。使用model.frame(),您还可以获得转换后的变量,例如log()Surv()factor() 版本的变量,这些变量通常是动态创建的。 顺便说一句:as.party() 强制 rpart 对象现在默认保留数据!因此,您可以使用as.party(fit, data = TRUE)(这是新的默认值)或as.party(fit, data = FALSE)(对应于旧行为)。【参考方案4】:

另一种方法是,通过查找任何特定节点的所有终端节点并返回调用中使用的数据子集来工作。

fit <- rpart(Kyphosis ~ Age + Start, data = kyphosis)

head(subset.rpart(fit, 5))
#    Kyphosis Age Number Start
# 2    absent 158      3    14
# 10  present  59      6    12
# 11  present  82      5    14
# 14   absent   1      4    12
# 18   absent 175      5    13
# 20   absent  27      4     9


subset.rpart <- function(tree, node = 1L) 
  data <- eval(tree$call$data, parent.frame(1L))
  wh <- sapply(as.integer(rownames(tree$frame)), parent)
  wh <- unique(unlist(wh[sapply(wh, function(x) node %in% x)]))
  data[rownames(tree$frame)[tree$where] %in% wh[wh >= node], ]


parent <- function(x) 
  if (x[1] != 1)
    c(Recall(if (x %% 2 == 0L) x / 2 else (x - 1) / 2), x) else x

【讨论】:

【参考方案5】:

原帖两年后,但可能对其他人有用。 rpart中训练观察的节点分配可以从$where获取:

fit <- rpart(Kyphosis ~ Age + Start, data = kyphosis)
fit$where

作为一个函数:

get_node <- function(rpart.object=fit, data=kyphosis, node.number=5) 
  data[which(fit$where == node.number),]  

get_node()

但这仅适用于训练观察,不适用于新观察。

【讨论】:

以上是关于在 rpart 的节点中获取观察结果(即:CART)的主要内容,如果未能解决你的问题,请参考以下文章

将 rpart 规则导出到数据框并链接规则以训练数据

CART 决策树中的冲突拆分

如何遍历R中rpart对象的树结构?我需要获取与子树关联的所有节点,我该怎么做?

R中常用数据挖掘算法包

R中的ROC曲线使用rpart包?

有人可以解释一下 ID3 和 CART 算法之间的区别吗?