使用 rpart 生成 sankey 图的决策树

Posted

技术标签:

【中文标题】使用 rpart 生成 sankey 图的决策树【英文标题】:Decision tree using rpart to produce a sankey diagram 【发布时间】:2019-02-11 14:39:24 【问题描述】:

我可以使用作为基础 R 一部分的 Kyphosis 数据集创建带有 Rpart 的树:

fit <- rpart(Kyphosis ~ Age + Number + Start,
         method="class", data=kyphosis)
printcp(fit)
plot(fit, uniform=TRUE,main="Classification Tree for Kyphosis")
text(fit, use.n=TRUE, all=TRUE, cex=.8)

这是树的样子:

现在为了更好地可视化树,我想使用 plotly 使用 sankey 图。要在 plotly 中创建 sankey 图,必须执行以下操作:

library(plotly)
nodes=c("Start>=8.5","Start>-14.5","absent",
                   "Age<55","absent","Age>=111","absent","present","present")
p <- plot_ly(
  type = "sankey",
  orientation = "h",      
  node = list(
    label = nodes,
    pad = 10,
    thickness = 20,
    line = list(
      color = "black",
      width = 0.5
    )
  ),

  link = list(
    source = c(0,1,1,3,3,5,5,0),
    target = c(1,2,3,4,5,6,7,8),
    value =  c(1,1,1,1,1,1,1,1)
  )
) %>% 
  layout(
    title = "Desicion Tree",
    font = list(
      size = 10
    )
  )
p

这将创建一个与树相对应的桑基图(硬编码)。所需的三个必要向量是“源”、“目标”、“值”,如下所示:

硬编码桑基图:

我的问题是使用 rpart 对象 'fit' 我似乎无法轻松获得一个向量来为 plotly 生成所需的 'source'、'target' 和 'value' 向量。

fit$frame 和 fit$splits 包含一些信息,但很难将它们汇总或一起使用。在 fit 对象上使用 print 功能会产生所需的信息,但我不想进行文本编辑来获取它。

print(fit)

输出:

1) root 81 17 absent (0.79012346 0.20987654)  
   2) Start>=8.5 62  6 absent (0.90322581 0.09677419)  
     4) Start>=14.5 29  0 absent (1.00000000 0.00000000) *
     5) Start< 14.5 33  6 absent (0.81818182 0.18181818)  
      10) Age< 55 12  0 absent (1.00000000 0.00000000) *
      11) Age>=55 21  6 absent (0.71428571 0.28571429)  
        22) Age>=111 14  2 absent (0.85714286 0.14285714) *
        23) Age< 111 7  3 present (0.42857143 0.57142857) *
   3) Start< 8.5 19  8 present (0.42105263 0.57894737) *

那么有没有一种简单的方法可以使用 rpart 对象来获取这 3 个向量以便 plotly 生成 sankey 图?该图将在 Web 应用程序中使用,因此必须使用 plotly,因为我们已经拥有与之对应的 javascript,并且它必须易于重用以应用于各种数据集。

【问题讨论】:

能否请您粘贴数据,以便我们可以轻松地重新创建fit 对象? 如果我没记错的话,Kyphosis 数据是基础 R 附带的“rpart”包的一部分。所以你可以按原样使用代码。 【参考方案1】:

这是我的尝试:

据我所知,挑战是生成nodessource变量。

样本数据:

fit <- rpart(Kyphosis ~ Age + Number + Start,
             method="class", data=kyphosis)

生成nodes:

frame <- fit$frame
isLeave <- frame$var == "<leaf>"
nodes <- rep(NA, length(isLeave))
ylevel <- attr(fit, "ylevels")
nodes[isLeave] <- ylevel[frame$yval][isLeave]
nodes[!isLeave] <- labels(fit)[-1][!isLeave[-length(isLeave)]]

生成source:

node <- as.numeric(row.names(frame))
depth <- rpart:::tree.depth(node)
source <- depth[-1] - 1

reps <- rle(source)
tobeAdded <- reps$values[sapply(reps$values, function(val) sum(val >= which(reps$lengths > 1))) > 0]
update <- source %in% tobeAdded
source[update] <- source[update] + sapply(tobeAdded, function(tobeAdd) rep(sum(which(reps$lengths > 1) <= tobeAdd), 2))

经测试:

library(rpart)
fit <- rpart(Kyphosis ~ Age + Number + Start,
             method="class", data=kyphosis)
fit2 <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis,
              parms = list(prior = c(.65,.35), split = "information"))

如何到达:

见:getS3method("print", "rpart")

【讨论】:

您好 BigDataScientist,感谢您的回复。您的代码包含一些对我的解决方案非常有用的东西。我不知道 getS3method 有很大帮助。我在 iris 数据集上尝试了您的解决方案:fit 嘿@BigDataScientist 我在我们的应用程序中发布了我们采用的“临时”解决方案(我不知道如何向您展示我们所做的)。但是仍然欢迎您更改您的用于 iris 数据集的工作。如果你让它工作,我可能会使用你的解决方案而不是我的解决方案,因为在我们的前端加载额外的库需要时间。但是再次感谢您,您的代码中肯定有非常有用的东西(如您所见,我已经在使用您的方法来获取节点向量)。 嘿,对不起,马特我病了(现在还在,...)。我希望代码至少可以帮助你一点,...【参考方案2】:

我暂时有一个临时解决方案。我只是不喜欢加载额外的库。但这里是: 为 Iris 数据集拟合模型:

fit <- rpart(Species~Sepal.Length +Sepal.Width   ,
         method="class", data=iris)

printcp(fit)
plot(fit, uniform=TRUE, 
     main="Classification Tree for IRIS")
text(fit, use.n=TRUE, all=TRUE, cex=.8)

我用来获取节点名称的方法是:

treeFrame=fit$frame
nodes=sapply(row.names(treeFrame),function(x) unlist(rpart::path.rpart(fit,x))
        [length(unlist(rpart::path.rpart(fit,x)))])

但在@BigDataScientist 解决方案中有更好的方法:

treeFrame=fit$frame
isLeave <- treeFrame$var == "<leaf>"
nodes <- rep(NA, length(isLeave))
ylevel <- attr(fit, "ylevels")
nodes[isLeave] <- ylevel[treeFrame$yval][isLeave]
nodes[!isLeave] <- labels(fit)[-1][!isLeave[-length(isLeave)]]

现在获取源和目标仍然有点棘手,但帮助我的是 rpart.utils 包:

library('rpart.utils')
treeFrame=fit$frame
treeRules=rpart.utils::rpart.rules(fit)

targetPaths=sapply(as.numeric(row.names(treeFrame)),function(x)  
                      strsplit(unlist(treeRules[x]),split=","))

lastStop=  sapply(1:length(targetPaths),function(x) targetPaths[[x]] 
                      [length(targetPaths[[x]])])

oneBefore=  sapply(1:length(targetPaths),function(x) targetPaths[[x]] 
                      [length(targetPaths[[x]])-1])


target=c()
source=c()
values=treeFrame$n
for(i in 2:length(oneBefore))

  tmpNode=oneBefore[[i]]
  q=which(lastStop==tmpNode)

  q=ifelse(length(q)==0,1,q)
  source=c(source,q)
  target=c(target,i)


source=source-1
target=target-1

所以我不喜欢使用额外的库,但这似乎适用于各种数据集。并且使用@BigDataScientist 获取节点的方式更好。但我仍然会寻找更好的解决方案。 @BigDataScientist 我认为您的解决方案可能会更好地工作,也许需要更改一些小东西。但我还不太了解您代码中的“reps”部分。

最后的情节代码是:

 p <- plot_ly(
 type = "sankey",
 orientation = "v",

 node = list(
     label = nodes,
     pad = 15,
     thickness = 20,
     line = list(
     color = "black",
     width = 0.5
     )
 ),

 link = list(
     source = source,
     target = target,
     value=values[-1]

 )
 ) %>% 
 layout(
     title = "Basic Sankey Diagram",
     font = list(
     size = 10
     )
 )
 p

【讨论】:

以上是关于使用 rpart 生成 sankey 图的决策树的主要内容,如果未能解决你的问题,请参考以下文章

Plotly:如何使用 pandas 数据框定义 sankey 图的结构?

从 rpart 包中的决策规则中提取信息

插入符号 rpart 决策树绘图结果

使用 rpart 为决策树修剪选择 CP 值

Rpart R决策树分数[重复]

rpart 不在 R 中创建决策树,SVM 有效