ctree() - 如何获取每个终端节点的拆分条件列表?
Posted
技术标签:
【中文标题】ctree() - 如何获取每个终端节点的拆分条件列表?【英文标题】:ctree() - How to get the list of splitting conditions for each terminal node? 【发布时间】:2014-02-21 23:19:03 【问题描述】:我有来自 `ctree()`（`party` 包）的输出，如下所示。如何获取每个终端节点的拆分条件列表，例如 `sns <= 0, dta <= 1; sns <= 0, dta > 1` 等？
1) sns <= 0; criterion = 1, statistic = 14655.021
2) dta <= 1; criterion = 1, statistic = 3286.389
3)* weights = 153682
2) dta > 1
4)* weights = 289415
1) sns > 0
5) dta <= 2; criterion = 1, statistic = 1882.439
6)* weights = 245457
5) dta > 2
7) dta <= 6; criterion = 1, statistic = 1170.813
8)* weights = 328582
7) dta > 6
谢谢
【问题讨论】:
【参考方案1】:我把 `CtreePathFunc` 函数改写成了更接近 Hadley-verse（tidyverse）风格的写法（我认为这样更易于理解），并且它还能处理分类变量。
library(magrittr)
# Return the split point of a party split object in readable form.
#
# nodeSplit: the split of an inner node, a list with a `splitpoint` element.
# Returns the factor level(s) indexed by `splitpoint` when it carries a
#   "levels" attribute (ordered-factor splits); otherwise `splitpoint` as a
#   plain numeric value (numeric splits).
readSplitter <- function(nodeSplit) {
  splitPoint <- nodeSplit$splitpoint
  splitLevels <- attr(splitPoint, "levels", exact = TRUE)
  if (is.null(splitLevels)) {
    as.numeric(splitPoint)
  } else {
    # the split point is an index into the factor levels
    splitLevels[splitPoint]
  }
}
# Row indices of the observations that reach the child taken at step
# `pathNumber` of `path` (the name's spelling is kept for existing callers).
#
# ct: a party BinaryTree (accessed through nodes()).
# path: IDs of the inner nodes from the root towards `terminalNode`.
# terminalNode: ID of the terminal node the path ends in.
# pathNumber: position in `path`; the child of path[pathNumber] is
#   path[pathNumber + 1], or `terminalNode` for the last step.
# Returns the integer indices whose weight in that child is non-zero.
hasWeigths <- function(ct, path, terminalNode, pathNumber) {
  # plain if/else: the condition is a scalar, so ifelse() is not needed
  childNode <- if (pathNumber == length(path)) {
    terminalNode
  } else {
    path[pathNumber + 1]
  }
  which(nodes(ct, childNode)[[1]]$weights != 0)
}
# Readable condition for step `pathNumber` of `path`, e.g. "dta <= 1".
#
# ct: a party BinaryTree. dts: the data `ct` was fitted on.
# path, terminalNode, pathNumber: as for hasWeigths().
# Returns the string built by the buildDataFilter() method matching the
#   class of the split stored in node path[pathNumber].
dataFilter <- function(ct, dts, path, terminalNode, pathNumber) {
  whichWeights <- hasWeigths(ct, path, terminalNode, pathNumber)
  # element 5 of a party node holds its split (splitpoint, variableName)
  nodeSplit <- nodes(ct, path[pathNumber])[[1]][[5]]
  buildDataFilter(nodeSplit, dts, whichWeights)
}
# S3 generic: turn a party split object into a readable filter condition.
# Dispatches on the class of `nodeSplit` ("nominalSplit" / "orderedSplit");
# extra arguments are forwarded to the method.
buildDataFilter <- function(nodeSplit, ...) {
  UseMethod("buildDataFilter")
}
# Describe a nominal (unordered factor) split as "<var> == <lvl1>, <lvl2>, ...".
#
# nodeSplit: party split object with a `variableName` element.
# dts: the data the tree was fitted on (data.frame or tibble).
# whichWeights: row indices of `dts` that fall into the child node.
# Returns one string listing the distinct levels present in those rows, in
#   order of first appearance.
buildDataFilter.nominalSplit <- function(nodeSplit, dts, whichWeights) {
  varName <- nodeSplit$variableName
  # `[[` then row-subsetting works for tibbles as well as data.frames
  includedLevels <- unique(dts[[varName]][whichWeights])
  paste(varName, "==", paste(includedLevels, collapse = ", "))
}
# Describe an ordered split (numeric or ordered factor) as "<var> <= x" or
# "<var> > x", depending on which side the child node's data lie.
#
# nodeSplit: party split object with `variableName` and `splitpoint`.
# dts: the data the tree was fitted on (data.frame or tibble).
# whichWeights: row indices of `dts` that fall into the child node.
# Returns one string.
buildDataFilter.orderedSplit <- function(nodeSplit, dts, whichWeights) {
  varName <- nodeSplit$variableName
  splitter <- readSplitter(nodeSplit)
  nodeValues <- dts[[varName]][whichWeights]
  # na.rm: missing values in the node must not turn the answer into NA
  # (the original produced strings such as "x NA 3" in that case)
  allBelow <- all(nodeValues <= splitter, na.rm = TRUE)
  paste(varName, if (allBelow) "<=" else ">", splitter)
}
# For every terminal node of a party::ctree() fit, build the conjunction of
# split conditions on the way from the root, e.g. "sns <= 0 & dta <= 1".
#
# ct: a party BinaryTree (uses where() and nodes()).
# dts: the data `ct` was fitted on.
# Returns a data.frame with one row per terminal node (order of where(ct)):
#   Node (terminal node ID) and Path (conditions joined by " & "; "" for a
#   tree that is only a root).
readTerminalNodePaths <- function(ct, dts) {
  terminalNodes <- unique(where(ct))
  # an observation belongs to a node when its weight there is non-zero
  # (comparing with 1, as before, fails for case-weighted fits)
  inNode <- function(node) nodes(ct, node)[[1]]$weights > 0

  # Inner nodes on the path to `terminalNode`: non-terminal IDs below it that
  # share observations with it (this relies on ancestors having smaller IDs).
  pathForTerminalNode <- function(terminalNode) {
    candidates <- setdiff(seq_len(terminalNode - 1), terminalNodes)
    members <- inNode(terminalNode)
    Filter(function(innerNode) any(members & inNode(innerNode)), candidates)
  }

  rows <- lapply(terminalNodes, function(terminalNode) {
    path <- pathForTerminalNode(terminalNode)
    # seq_along() keeps an empty path empty (seq(0) would give c(1, 0))
    conditions <- vapply(
      seq_along(path),
      function(step) dataFilter(ct, dts, path, terminalNode, step),
      character(1)
    )
    data.frame(Node = terminalNode, Path = paste(conditions, collapse = " & "))
  })
  do.call(rbind, rows)
}
测试
# Divide the leading share of a vector by a constant (used below to skew the
# simulated test data).
#
# vctr: numeric vector.
# divideBy: divisor applied to the leading elements.
# proportion: share of elements, counted from the start, to rescale; the
#   count is round(length(vctr) * proportion).
# Returns `vctr` with those leading elements divided by `divideBy`.
shiftFirstPart <- function(vctr, divideBy, proportion = .5) {
  # seq_len() yields an empty index when the rounded count is 0; seq() would
  # give c(1, 0) and wrongly rescale the first element
  firstPart <- seq_len(round(length(vctr) * proportion))
  vctr[firstPart] <- vctr[firstPart] / divideBy
  vctr
}
# Demo: simulate data with skewed predictors, fit a tree and print the filter
# leading to each terminal node. ctree(), where() and nodes() come from the
# party package; %>% comes from magrittr.
library(party)
set.seed(11)
n <- 13000
gdt <- data.frame(
  is_buyer = runif(n) %>% shiftFirstPart(1.5) %>% round %>%
    factor(labels = c("no", "yes")),
  # cut()'s argument is `include.lowest`; the previous `include_lowest` was
  # silently absorbed by cut()'s `...` and had no effect
  age = runif(n) %>% shiftFirstPart(1.5) %>%
    cut(breaks = c(0, .3, .6, 1), include.lowest = TRUE,
        ordered_result = TRUE, labels = c("low", "mid", "high")),
  city = runif(n) %>% shiftFirstPart(1.5) %>%
    cut(breaks = c(0, .3, .6, 1), include.lowest = TRUE,
        labels = c("Chigaco", "Boston", "Memphis")),
  point = runif(n) %>% shiftFirstPart(1.2)
)
gct <- ctree(is_buyer ~ ., data = gdt)
readTerminalNodePaths(gct, gdt)
【讨论】:
这是一个非常酷的函数。我不确定自己是否完全理解它的工作原理，但有没有简单的方法修改它，以获取*所有*节点（而不仅仅是终端节点）的路径？ 抱歉问个傻问题：函数 `where` 来自哪个包？ @Manuel Chirouze，来自 `party` 包。
【参考方案2】:
如果您使用新的、推荐的 `partykit` 包中的 `ctree()` 实现，而不是旧的 `party` 包，那么可以使用函数 `.list.rules.party()`。该函数尚未正式导出，但可以用来提取所需的信息。
# Same task with the partykit implementation of ctree().
library("partykit")
# keep only rows with an observed response
airq <- subset(airquality, !is.na(Ozone))
ct <- ctree(Ozone ~ ., data = airq)
# .list.rules.party() is not exported (hence `:::`), so it is not a stable
# API. It returns one rule string per terminal node, named by node ID (see the
# recorded output below).
partykit:::.list.rules.party(ct)
## 3 5
## "Temp <= 82 & Wind <= 6.9" "Temp <= 82 & Wind > 6.9 & Temp <= 77"
## 6 8
## "Temp <= 82 & Wind > 6.9 & Temp > 77" "Temp > 82 & Wind <= 10.3"
## 9
## "Temp > 82 & Wind > 10.3"
【讨论】:
【参考方案3】:由于我需要一个能处理分类数据的版本，我写了下面的函数，它或多或少回答了 @JoãoDaniel 的问题（我只用分类预测变量测试过）：
# Strip leading and trailing whitespace from every element of `x`.
# Adapted from https://stackoverflow.com/questions/2261079/how-to-trim-leading-and-trailing-whitespace-in-r
trim <- function(x) {
  gsub(pattern = "^\\s+|\\s+$", replacement = "", x = x)
}
# Text before the first whitespace of each element of `x` (the variable name
# of a rule such as "Temp <= 82"); elements without whitespace are returned
# unchanged.
getVariable <- function(x) {
  sub(pattern = "(.*?)[[:space:]].*", replacement = "\\1", x = x)
}
# Second whitespace-delimited token of each element of `x` (the comparison
# operator of a rule such as "Temp <= 82"); elements that do not match are
# returned unchanged.
getSimbolo <- function(x) {
  sub(pattern = "(.*?)[[:space:]](.*?)[[:space:]].*", replacement = "\\2", x = x)
}
# Simplify ";"-separated rule paths. For every (variable, operator) pair only
# its last condition is kept, which is the tightest one on the path. Example:
# "Temp <= 82 ; Wind > 6.9 ; Temp <= 77" -> "Temp <= 77; Wind > 6.9".
#
# elemento: character vector of paths; each condition has the form
#   "<variable> <operator> <value...>".
# Returns a character vector of the same length. Within each path the kept
#   conditions appear in the order their pair first occurs, joined by "; ".
#   An empty path gives "".
getReglaFinal <- function(elemento) {
  simplificar <- function(reglas) {
    reglas <- trim(reglas)
    tipoCorte <- paste0(getVariable(reglas), getSimbolo(reglas))
    # position of the last condition of each pair, in first-appearance order
    ultimo <- vapply(
      unique(tipoCorte),
      function(tipo) max(which(tipoCorte == tipo)),
      integer(1),
      USE.NAMES = FALSE
    )
    paste(reglas[ultimo], collapse = "; ")
  }
  unname(vapply(strsplit(as.character(elemento), ";"), simplificar, character(1)))
}
# Describe every terminal node of a party::ctree() fit by the split conditions
# leading to it (numeric, ordered-factor and unordered-factor predictors), and
# render the result as an SQL CASE expression.
#
# ct: a BinaryTree returned by party::ctree() (uses where(), nodes(), Predict()).
# Returns a list with
#   PreTable: one row per terminal node, sorted by Means. Columns:
#     Node (character ID),
#     Path (every condition on the way, "; "-separated, values wrapped in "{ }"),
#     Means (the node's prediction),
#     Counts (sum of the node's weights),
#     TruePath (Path simplified by getReglaFinal()).
#   Table: PreTable with Path replaced by TruePath.
#   SQL: " CASE WHEN ... THEN 'Nodo <id>' ... END ".
CtreePathFuncAllCat <- function(ct) {
  obsNode <- where(ct)
  terminales <- unique(obsNode)
  # membership test: non-zero weight (comparing with 1 fails for case weights)
  enNodo <- function(id) nodes(ct, id)[[1]]$weights > 0

  # Inner nodes from the root to terminal `Node`: non-terminal IDs below it
  # that share observations with it (relies on ancestors having smaller IDs).
  caminoA <- function(Node) {
    candidatos <- setdiff(seq_len(Node - 1), terminales)
    miembros <- enNodo(Node)
    Filter(function(i) any(miembros & enNodo(i)), candidatos)
  }

  # "; "-joined conditions "var op { value }" along the path to `Node`.
  describirCamino <- function(Node) {
    Path <- caminoA(Node)
    nombres <- character(length(Path))
    puntos <- vector("list", length(Path))
    piezas <- character(length(Path))
    for (i in seq_along(Path)) {
      n <- nodes(ct, Path[i])[[1]]
      corte <- n[[5]]
      nextNodeID <- if (i == length(Path)) Node else Path[i + 1]
      haciaDerecha <- nextNodeID == n$right$nodeID
      vec_nombre <- corte$variableName
      vec_niveles <- attr(corte$splitpoint, "levels")
      vec_puntos <- as.vector(corte$splitpoint)
      # The split kind is decided from the structure of `splitpoint`, never
      # from its value: the old `index == 0` sentinel misread a numeric split
      # at exactly 0 as a nominal split.
      if (length(vec_niveles) == 0) {
        # numeric split: the left child holds values <= splitpoint
        SB <- if (haciaDerecha) ">" else "<="
        descripcion <- corte$splitpoint
      } else if (length(vec_puntos) != length(vec_niveles)) {
        # ordered factor: splitpoint indexes the boundary level
        SB <- if (haciaDerecha) ">" else "<="
        descripcion <- paste(vec_niveles[vec_puntos], collapse = ", ")
      } else {
        # unordered factor: splitpoint flags the levels sent to the left child
        sel <- as.logical(vec_puntos)
        if (haciaDerecha) sel <- !sel
        # intersect with earlier splits on the same variable
        for (j in seq_len(i - 1)) {
          if (nombres[j] == vec_nombre && !is.null(puntos[[j]])) {
            sel <- sel & puntos[[j]]
          }
        }
        puntos[[i]] <- sel
        SB <- "="
        descripcion <- paste(vec_niveles[sel], collapse = ", ")
      }
      nombres[i] <- vec_nombre
      # "{ ... }" marks the value so condicionSql() can locate it
      piezas[i] <- paste(c(vec_nombre, SB, "{", descripcion, "}"), collapse = " ")
    }
    paste(piezas, collapse = "; ")
  }

  # Turn "a <= { 1 }; b = { x, y }" into "a <= (1) AND b IN ('x','y')".
  condicionSql <- function(regla) {
    regla <- gsub("'", "''", regla, fixed = TRUE)      # escape quotes in labels
    regla <- gsub(";", " AND", regla, fixed = TRUE)
    regla <- gsub("{ ", "('", regla, fixed = TRUE)
    regla <- gsub(" }", "')", regla, fixed = TRUE)
    regla <- gsub(", ", "','", regla, fixed = TRUE)
    # "=" is only produced for level sets; SQL needs IN for a value list
    regla <- gsub(" = (", " IN (", regla, fixed = TRUE)
    # NOTE(review): ordered-factor conditions are compared as strings in SQL,
    # which matches the factor order only if the labels sort the same way.
    gsub("'([-+]?([0-9]*\\.[0-9]+|[0-9]+))'", "\\1", regla)   # unquote numbers
  }

  # Take Means from the first observation of each node rather than unique():
  # unique() misaligned rows whenever two nodes shared the same prediction.
  primera <- match(terminales, obsNode)
  Means <- unname(drop(Predict(ct)))[primera]
  Counts <- vapply(terminales, function(id) sum(nodes(ct, id)[[1]]$weights), numeric(1))

  ResulTable <- data.frame(
    Node = as.character(terminales),
    Path = vapply(terminales, describirCamino, character(1)),
    Means = Means,
    Counts = Counts,
    stringsAsFactors = FALSE
  )
  ResulTable <- ResulTable[order(ResulTable$Means), ]
  ResulTable$TruePath <- vapply(ResulTable$Path, getReglaFinal, character(1), USE.NAMES = FALSE)

  Table <- ResulTable
  Table$Path <- Table$TruePath
  Table$TruePath <- NULL

  whens <- paste0(
    "WHEN ", condicionSql(ResulTable$TruePath),
    " THEN 'Nodo ", ResulTable$Node, "'"
  )
  list(
    PreTable = ResulTable,
    Table = Table,
    SQL = paste0(" CASE ", paste(whens, collapse = " "), " END ")
  )
}
这是一个测试:
# Examples for CtreePathFuncAllCat(). The "##" lines are output recorded in
# the original post; the current function may not reproduce them
# byte-for-byte (spacing, value delimiters, SQL details).
library(party)
# Example 1: ordered-factor predictors
TreeModel1 = ctree(PB~ME+SYMPT+HIST+BSE+DECT, data = mammoexp)
Result2 <- CtreePathFuncAllCat(TreeModel1)
Result2
##$PreTable
## Node Path Means Counts
##3 7 DECT > Somewhat likely ; SYMPT > Disagree 6.526316 114
##2 6 DECT > Somewhat likely ; SYMPT <= Disagree 7.640000 175
##1 4 DECT <= Somewhat likely ; DECT > Not likely 8.161905 105
##4 3 DECT <= Somewhat likely ; DECT <= Not likely 9.833333 18
## TruePath
##3 DECT > Somewhat likely ; SYMPT > Disagree
##2 DECT > Somewhat likely ; SYMPT <= Disagree
##1 DECT <= Somewhat likely ; DECT > Not likely
##4 DECT <= Not likely
##
##$Table
## Node Path Means Counts
##3 7 DECT > Somewhat likely ; SYMPT > Disagree 6.526316 114
##2 6 DECT > Somewhat likely ; SYMPT <= Disagree 7.640000 175
##1 4 DECT <= Somewhat likely ; DECT > Not likely 8.161905 105
##4 3 DECT <= Not likely 9.833333 18
##
##$SQL
##[1] " CASE WHEN DECT > ('Somewhat likely') AND SYMPT > ('Disagree') THEN 'Nodo 7' WHEN DECT > ('Somewhat likely') AND SYMPT <= ('Disagree') THEN 'Nodo 6' WHEN DECT <= ('Somewhat likely') AND DECT > ('Not likely') THEN 'Nodo 4' WHEN DECT <= ('Not likely') THEN 'Nodo 3' END "
# Example 2: unordered-factor predictor (also plots the tree)
TreeModel2 = ctree(count~spray, data = InsectSprays)
plot(TreeModel2, type="simple")
Result2 <- CtreePathFuncAllCat(TreeModel2)
Result2
##$PreTable
##Node Path Means Counts TruePath
##2 5 spray = C, D, E ; spray = C, E 2.791667 24 spray = C, E
##3 4 spray = C, D, E ; spray = D 4.916667 12 spray = D
##1 2 spray = A, B, F 15.500000 36 spray = A, B, F
##
##$Table
##Node Path Means Counts
##2 5 spray = C, E 2.791667 24
##3 4 spray = D 4.916667 12
##1 2 spray = A, B, F 15.500000 36
##
##$SQL
##[1] " CASE WHEN spray = ('C','E') THEN 'Nodo 5' WHEN spray = ('D') THEN 'Nodo 4' WHEN spray = ('A','B','F') THEN 'Nodo 2' END "
# Example 3: continuous predictors (rows with missing Ozone dropped first)
airq <- subset(airquality, !is.na(Ozone))
TreeModel3 <- ctree(Ozone ~ ., data = airq, controls = ctree_control(maxsurrogate = 3))
Result2 <- CtreePathFuncAllCat(TreeModel3)
Result2
##$PreTable
## Node Path Means Counts
##1 5 Temp <= 82 ; Wind > 6.9 ; Temp <= 77 18.47917 48
##3 6 Temp <= 82 ; Wind > 6.9 ; Temp > 77 31.14286 21
##4 9 Temp > 82 ; Wind > 10.3 48.71429 7
##2 3 Temp <= 82 ; Wind <= 6.9 55.60000 10
##5 8 Temp > 82 ; Wind <= 10.3 81.63333 30
## TruePath
##1 Temp <= 77 ; Wind > 6.9
##3 Temp <= 82 ; Wind > 6.9 ; Temp > 77
##4 Temp > 82 ; Wind > 10.3
##2 Temp <= 82 ; Wind <= 6.9
##5 Temp > 82 ; Wind <= 10.3
##
##$Table
## Node Path Means Counts
##1 5 Temp <= 77 ; Wind > 6.9 18.47917 48
##3 6 Temp <= 82 ; Wind > 6.9 ; Temp > 77 31.14286 21
##4 9 Temp > 82 ; Wind > 10.3 48.71429 7
##2 3 Temp <= 82 ; Wind <= 6.9 55.60000 10
##5 8 Temp > 82 ; Wind <= 10.3 81.63333 30
##
##$SQL
##[1] " CASE WHEN Temp <= (77) AND Wind > (6.9) THEN 'Nodo 5' WHEN Temp <= (82) AND Wind > (6.9) AND Temp > (77) THEN 'Nodo 6' WHEN Temp > (82) AND Wind > (10.3) THEN 'Nodo 9' WHEN Temp <= (82) AND Wind <= (6.9) THEN 'Nodo 3' WHEN Temp > (82) AND Wind <= (10.3) THEN 'Nodo 8' END "
更新!现在该函数支持分类变量和数值变量的混合!
【讨论】:
很好用，但它似乎只适用于分类变量：当我对 airct 树运行 CtreePathFuncAllCat(ct) 时，它返回了拆分字段，却没有返回拆分标准。知道如何同时获取分类变量和连续变量的路径吗？ @clevelandfrowns 我更新了函数，现在可以同时处理连续数据和分类数据。【参考方案4】:这个函数应该可以完成这项工作
# List, for each terminal node of a party::ctree() fit, the split conditions
# on the way from the root, e.g. "Temp <= 82, Wind > 6.9".
#
# ct: a BinaryTree returned by party::ctree() (uses where() and nodes()).
# data: the data used to fit `ct`. It is no longer needed, because the side
#   of each split is read from the tree structure instead of being inferred
#   from the data (which failed with NA values). The parameter is kept so
#   existing calls still work.
# Returns a data.frame with one row per terminal node (order of where(ct)):
#   Node (character ID) and Path (conditions joined by ", "; "" for a
#   root-only tree). Numeric and ordered splits render as "<=" / ">",
#   unordered-factor splits as "var = lvl1, lvl2".
CtreePathFunc <- function(ct, data = NULL) {
  terminales <- unique(where(ct))
  # membership test: non-zero weight (comparing with 1 fails for case weights)
  enNodo <- function(id) nodes(ct, id)[[1]]$weights > 0

  # Condition that leads from inner node `padreID` to its child `hijoID`.
  condicion <- function(padreID, hijoID) {
    padre <- nodes(ct, padreID)[[1]]
    corte <- padre[[5]]                       # the node's split
    derecha <- hijoID == padre$right$nodeID   # left child holds "<=" side
    niveles <- attr(corte$splitpoint, "levels")
    puntos <- as.vector(corte$splitpoint)
    if (!is.null(niveles) && length(puntos) == length(niveles)) {
      # unordered factor: splitpoint flags the levels sent left
      enviados <- as.logical(puntos)
      if (derecha) enviados <- !enviados
      paste(corte$variableName, "=", paste(niveles[enviados], collapse = ", "))
    } else {
      # numeric value, or index of the boundary level for ordered factors
      valor <- if (is.null(niveles)) puntos else niveles[puntos]
      paste(corte$variableName, if (derecha) ">" else "<=", valor)
    }
  }

  describir <- function(Node) {
    miembros <- enNodo(Node)
    # inner nodes below `Node` sharing observations with it form its path
    Path <- Filter(
      function(i) any(miembros & enNodo(i)),
      setdiff(seq_len(Node - 1), terminales)
    )
    hijos <- c(Path[-1], Node)
    condiciones <- vapply(
      seq_along(Path),
      function(k) condicion(Path[k], hijos[k]),
      character(1)
    )
    paste(condiciones, collapse = ", ")
  }

  data.frame(
    Node = as.character(terminales),
    Path = vapply(terminales, describir, character(1)),
    stringsAsFactors = FALSE
  )
}
测试
# Example for CtreePathFunc() on the airquality data; the "##" lines are the
# output recorded in the original post.
library(party)
# drop rows with a missing response before fitting
airq <- subset(airquality, !is.na(Ozone))
ct <- ctree(Ozone ~ ., data = airq, controls = ctree_control(maxsurrogate = 3))
Result <- CtreePathFunc(ct, airq)
Result
## Node Path
## 1 5 Temp <= 82, Wind > 6.9, Temp <= 77
## 2 3 Temp <= 82, Wind <= 6.9
## 3 6 Temp <= 82, Wind > 6.9, Temp > 77
## 4 9 Temp > 82, Wind > 10.3
## 5 8 Temp > 82, Wind <= 10.3
【讨论】:
虽然等了很久，但这是个非常好的回答。不过你忘了把数据 `airq` 作为参数传入。 谢谢，@Galled。已编辑。我还忘了 `library(party)`。这是我在 SO 上最早的回答之一，所以当时还有点生疏。
这个函数是否有任何更新版本也可以处理分类解释变量? @DavidArenburg
@JoãoDaniel，我还没有写。也许可以发一个新问题，看看是否有人能进一步完善，因为我不确定近期是否有时间写一个。
@JoãoDaniel 我做了一个。以上是关于ctree() - 如何获取每个终端节点的拆分条件列表?的主要内容,如果未能解决你的问题,请参考以下文章
如何获取所有终端节点 - r 中的权重和响应预测“ctree”