从决策树进行预测的高效算法(使用 R)

Posted

技术标签:

【中文标题】从决策树进行预测的高效算法(使用 R)【英文标题】:Efficient algorithm for predicting from a decision tree (using R) 【发布时间】:2016-03-11 08:06:03 【问题描述】:

我正在修改 Brieman 的随机森林程序(我不懂 C/C++),所以我在 R 中从头开始编写了我自己的 RF 变体。我的程序和标准程序之间的区别基本上只是如何计算拆分点和终端节点中的值——一旦我在森林中有一棵树,就可以认为它与典型 RF 算法中的树非常相似。

我的问题是它的预测速度很慢,而且我很难想办法让它更快。

一个测试树对象链接here,一些测试数据链接here。你可以直接下载它,或者如果你安装了repmis,你可以在下面加载它。它们被称为testtreesampx

library(repmis)
testtree <- source_DropboxData(file = "testtree", key = "sfbmojc394cnae8")
sampx <- source_DropboxData(file = "sampx", key = "r9imf317hpflpsx")

编辑:不知何故,我还没有真正学习如何很好地使用 github。我已将所需文件上传到存储库 here - 抱歉,我目前无法弄清楚如何获取永久链接...

看起来像这样(使用我编写的绘图函数):

这里有一点关于对象的结构:

1> summary(testtree)
         Length Class      Mode   
nodes       7   -none-     list   
minsplit    1   -none-     numeric
X          29   data.frame list   
y        6719   -none-     numeric
weights  6719   -none-     numeric
oob      2158   -none-     numeric
1> summary(testtree$nodes)
     Length Class  Mode
[1,] 4      -none- list
[2,] 8      -none- list
[3,] 8      -none- list
[4,] 7      -none- list
[5,] 7      -none- list
[6,] 7      -none- list
[7,] 7      -none- list
1> summary(testtree$nodes[[1]])
         Length Class  Mode   
y        6719   -none- numeric
output         1   -none- numeric
Terminal    1   -none- logical
children    2   -none- numeric
1> testtree$nodes[[1]][2:4]
$output
[1] 40.66925

$Terminal
[1] FALSE

$children
[1] 2 3

1> summary(testtree$nodes[[2]])
           Length Class  Mode     
y          2182   -none- numeric  
parent        1   -none- numeric  
splitvar      1   -none- character
splitpoint    1   -none- numeric  
handedness    1   -none- character
children      2   -none- numeric  
output        1   -none- numeric  
Terminal      1   -none- logical  
1> testtree$nodes[[2]][2:8]
$parent
[1] 1

$splitvar
[1] "bizrev_allHH"

$splitpoint
    25% 
788.875 

$handedness
[1] "Left"

$children
[1] 4 5

$output
[1] 287.0085

$Terminal
[1] FALSE

output 是该节点的返回值——我希望其他一切都是不言自明的。

我写的预测函数有效,但是太慢了。基本上它“走下树”,通过观察观察:

predict.NT = function(tree.obj, newdata=NULL)
    if (is.null(newdata))X = tree.obj$X else X = newdata
    tree = tree.obj$nodes
    if (length(tree)==1)#Return the mean for a stump
        return(rep(tree[[1]]$output,length(X)))
    
    pred = apply(X = newdata, 1, godowntree, nn=1, tree=tree)
    return(pred)


godowntree = function(x, tree, nn = 1)
    while (tree[[nn]]$Terminal == FALSE)
        fb = tree[[nn]]$children[1]
        sv = tree[[fb]]$splitvar
        sp = tree[[fb]]$splitpoint
        if (class(sp)=='factor')
            if (as.character(x[names(x) == sv]) == sp)
                nn<-fb
             else
                nn<-fb+1
            
         else 
            if (as.character(x[names(x) == sv]) < sp)
                nn<-fb
             else
                nn<-fb+1
            
        
    
    return(tree[[nn]]$output)

问题在于它真的很慢(当您考虑到非样本树更大,并且我需要这样做很多很多次时),即使对于一棵简单的树:

library(microbenchmark)
microbenchmark(predict.NT(testtree,sampx))
Unit: milliseconds
                        expr      min       lq     mean   median       uq
 predict.NT(testtree, sampx) 16.19845 16.36351 17.37022 16.54396 17.07274
     max neval
 40.4691   100

我今天从某人那里得到一个想法,我可以编写一个函数工厂类型的函数(即:一个生成闭包的函数,我刚刚学习)将我的树分解成一堆嵌套的 if/else陈述。然后我可以通过它发送数据,这可能比一遍又一遍地从树中提取数据更快。我还没有编写函数函数生成函数,但是我手写了我从中得到的那种输出,并测试了它:

predictif = function(x)
    if (x[names(x) == 'bizrev_allHH'] < 788.875)
        if (x[names(x) == 'male_head'] <.872)
            return(548)
         else 
            return(165)
        
     else 
        if (x[names(x) == 'nondurable_exp_mo'] < 4190.965)
            return(-283)
        else
            return(-11.4)
        
    

predictif.NT = function(tree.obj, newdata=NULL)
    if (is.null(newdata))X = tree.obj$X else X = newdata
    tree = tree.obj$nodes
    if (length(tree)==1)#Return the mean for a stump
        return(rep(tree[[1]]$output,length(X)))
    
    pred = apply(X = newdata, 1, predictif)
    return(pred)


microbenchmark(predictif.NT(testtree,sampx))
Unit: milliseconds
                          expr      min       lq     mean   median       uq
 predictif.CT(testtree, sampx) 12.77701 12.97551 14.21417 13.18939 13.67667
      max neval
 30.48373   100

快一点,但不多!

如果有任何想法可以加快速度,我将不胜感激!或者,如果答案是“如果不将其转换为 C/C++,你真的无法获得这么快的速度”,那也是有价值的信息(尤其是如果你给了我一些关于为什么会这样的信息)。

虽然我当然很欣赏 R 中的答案,但伪代码中的答案也会很有帮助。

谢谢!

【问题讨论】:

我在从 Dropbox 下载对象时遇到问题。您能否在您的问题中分享dput(testtree) 的结果? 我刚刚尝试了dput(testree),而且它很大。让我想出一个更好的方法来链接数据...... 也许您找到了一种将计算值存储在静态字典中的方法,例如缓存。在计算新值之前查看字典。 顺便说一句,到目前为止,最简单的方法是同时预先计算所有数据点的比较。很容易做到。一旦我有了你的数据,我会向你展示,但基础知识从 transform(sampx, n1 = bizrev_allHH &lt; 788.875, n2 = male_head &lt; .872) 等开始。基于此,它可以非常快速(无需使用 C 或 C++),并且具有可以对任何决策树进行一些工作。 @DavidRobinson 我已将文件添加到 github 存储库:github.com/mynameisnotdrew/test/tree/… 感谢您的建议,并提前感谢您的演示! 【参考方案1】:

加速函数的秘诀是矢量化。不要单独对每一行执行所有操作,而是一次对所有行执行这些操作。

让我们重新考虑您的 predictif 函数

predictif = function(x)
    if (x[names(x) == 'bizrev_allHH'] < 788.875)
        if (x[names(x) == 'male_head'] <.872)
            return(548)
         else 
            return(165)
        
     else 
        if (x[names(x) == 'nondurable_exp_mo'] < 4190.965)
            return(-283)
        else
            return(-11.4)
        
    

这是一种缓慢的方法,因为它将所有这些操作应用于每个单独的实例。函数调用、if 语句,尤其是像 names(x) == 'bizrev_allHH' 这样的操作都有一些开销,当您为每个实例执行这些操作时,这些开销都会增加。

相比之下,简单地比较两个数字非常快!因此,请编写上述的矢量化版本。

predictif_fast <- function(newdata) 
  n1 <- newdata$bizrev_allHH < 788.875
  n2 <- newdata$male_head < .872
  n3 <- newdata$nondurable_exp_mo < 4190.965

  ifelse(n1, ifelse(n2, 548.55893, 165.15537),
             ifelse(n3, -283.35145, -11.40185))

注意,这一点非常重要,这个函数没有被传递给一个实例。它旨在传递您的整个新数据。这是因为&lt;ifelse 操作都是向量化的:当给定一个向量时,它们返回一个向量。

让我们比较一下你的函数和这个新函数:

> microbenchmark(predictif.NT(testtree, sampx),
                 predictif_fast(sampx))
Unit: microseconds
                          expr       min         lq     mean    median         uq
 predictif.NT(testtree, sampx) 12106.419 13144.2390 14684.46 13719.406 14593.1565
         predictif_fast(sampx)   189.093   213.6505   263.74   246.192   260.7895
       max neval cld
 79136.335   100   b
  2344.059   100  a 

请注意,我们通过矢量化获得了 50 倍的加速。

顺便说一句,它可以大大加快速度(如果你对索引很聪明,还有更快的替代ifelse),但总体上从“对每一行执行一个函数”切换到“对整个向量执行操作” " 让您获得最大的加速。


这并不能完全解决您的问题,因为您需要在一般树上执行这些矢量化操作,而不仅仅是在这个特定的树上。我不会为您解决通用版本,但请考虑您可以重写您的godowntree 函数,以便它获取整个数据帧并在整个数据帧上执行其操作,而不仅仅是一个。然后,不要使用if 分支,而是保留每个实例当前所在的子节点的向量。

【讨论】:

谢谢!在你把它放在一起和你之前发表评论之间,我写了一个基于评论中的想法快约 10 倍的函数,但你这里的明显优于那个。从这里编写一个更通用的函数很简单。再次感谢。

以上是关于从决策树进行预测的高效算法(使用 R)的主要内容,如果未能解决你的问题,请参考以下文章

GBDT 算法:原理篇

根据决策树算法生成的模型进行预测

微软数据挖掘算法:Microsoft 决策树分析算法

机器学习算法决策树-6 PRISM

分类算法——决策树

R语言机器学习 | 8 决策树与集成学习