从随机森林中提取一棵树,然后使用提取的树进行预测

Posted

技术标签:

【中文标题】从随机森林中提取一棵树,然后使用提取的树进行预测【英文标题】:Extract a tree from a random forest and then use the extracted tree for prediction 【发布时间】:2020-09-02 17:05:24 【问题描述】:

我们以鸢尾花数据集为例。

library(randomForest)
data(iris)
smp_size <- floor(0.75 * nrow(iris))
train_ind <- sample(seq_len(nrow(iris)), size = smp_size)

train <- iris[train_ind, ]
test <- iris[-train_ind, ]

model <- randomForest(Species~., data = train, ntree=10)

如果我使用 randomForest 包中的 getTree() 函数,我可以毫无问题地提取例如第三棵树。

treefit <- getTree(model, 3)

但是,例如,我如何使用它(即 treefit)对测试集进行预测?像“predict()”,有没有一个函数可以直接做到这一点?

提前谢谢你

【问题讨论】:

如果你真的想使用那棵树,你将不得不使用底层的 c 代码,github.com/cran/randomForest/blob/master/R/… .. 我认为使用 predict.all 的答案是一个很好的解决方法..跨度> 【参考方案1】:

您可以通过将predict.all 参数设置为TRUE 来直接使用randomForest 包中的predict 函数。

请参阅以下可重现的代码以了解如何使用它:另请参阅predict.randomForest here 的帮助页面。

library(randomForest)
set.seed(1212)
x <- rnorm(100)
y <- rnorm(100, x, 10)
df_train <- data.frame(x=x, y=y)
x_test <- rnorm(20)
y_test <- rnorm(20, x_test, 10)
df_test <- data.frame(x = x_test, y = y_test)
rf_fit <- randomForest(y ~ x, data = df_train, ntree = 500)
# You get a list with the overall predictions and individual tree predictions
rf_pred <- predict(rf_fit, df_test, predict.all = TRUE)
rf_pred$individual[, 3] # Obtains the 3rd tree's predictions on the test data

【讨论】:

以上是关于从随机森林中提取一棵树,然后使用提取的树进行预测的主要内容,如果未能解决你的问题,请参考以下文章

有没有办法从随机森林模型中提取树深度?

随机森林回归中的树数

随机森林

从决策树进行预测的高效算法(使用 R)

PySpark 和 MLLib:随机森林预测的类概率

随机森林预测