Tidymodels 包:使用 ggplot() 可视化随机森林模型以显示最重要的预测变量

Posted

技术标签:

【中文标题】Tidymodels 包:使用 ggplot() 可视化随机森林模型以显示最重要的预测变量【英文标题】:Tidymodels Package: Visualising a random forest model using ggplot() to show the most important predictors 【发布时间】:2021-03-04 10:58:36 【问题描述】:

概述

我正在按照教程(见下文)从袋装树、随机森林、提升树和一般线性模型中找到最佳拟合模型。

教程(参见下面的示例)

https://bcullen.rbind.io/post/2020-06-02-tidymodels-decision-tree-learning-in-r/

问题

在这种情况下,我想进一步探索数据,并为 随机森林模型中的数据可视化最重要的预测变量(见下图)。

我的数据框称为 FID随机森林模型中的预测变量涉及:

    年份(数字) 月(因子) 天数(数字)

因变量是频率(数字)

当我尝试运行绘图以可视化最重要的预测变量时,我不断收到以下错误消息:-

Error: Problem with `mutate()` input `oob_rmse`.
x non-numeric argument to mathematical function
ℹ Input `oob_rmse` is `map_dbl(fit, ~sqrt(.x$prediction.error))`.
Run `rlang::last_error()` to see where the error occurred.
Called from: signal_abort(cnd)

如果有人对如何修复错误消息有任何建议,我将不胜感激。

在此先感谢

如何从教程中的 R 代码生成绘图的示例

可视化模型

绘制以显示教程中 R 代码中最重要的预测变量

我的 R 代码

##Open libraries
library(tidymodels)
library(parsnip)
library(forcats)
library(ranger)
library(baguette)

###########################################################
#split this single dataset into two: a training set and a testing set
data_split <- initial_split(FID)
# Create data frames for the two sets:
train_data <- training(data_split)
test_data  <- testing(data_split)

 # resample the data with 10-fold cross-validation (10-fold by default)
  cv <- vfold_cv(train_data, v=3)
###########################################################

##Produce the recipe

rec <- recipe(Frequency ~ ., data = FID) %>% 
          step_nzv(all_predictors(), freq_cut = 0, unique_cut = 0) %>% # remove variables with zero variances
          step_novel(all_nominal()) %>% # prepares test data to handle previously unseen factor levels 
          step_medianimpute(all_numeric(), -all_outcomes(), -has_role("id vars"))  %>% # replaces missing numeric observations with the median
          step_dummy(all_nominal(), -has_role("id vars")) # dummy codes categorical variables

###################################################################################


    ###################################################
    ##Random forests
    ###################################################
    
    mod_rf <-rand_forest(trees = 1e3) %>%
                                  set_engine("ranger",
                                  num.threads = parallel::detectCores(), 
                                  importance = "permutation", 
                                  verbose = TRUE) %>% 
                                  set_mode("regression") 
                                  
    ##Create Workflow
    
    wflow_rf <- workflow() %>% 
                   add_model(mod_rf) %>% 
                         add_recipe(rec)
    
    ##Fit the model
    
    plan(multisession)
    
    fit_rf<-fit_resamples(
                 wflow_rf,
                 cv,
                 metrics = metric_set(rmse, rsq),
                 control = control_resamples(save_pred = TRUE,
                                             extract = function(x) extract_model(x)))
    
    
    # extract roots
    rf_tree_roots <- function(x)
                         map_chr(1:1000, 
                            ~ranger::treeInfo(x, tree = .)[1, "splitvarName"])
                                
    
    rf_roots <- function(x)
                           x %>% 
                            dplyr::select(.extracts) %>% 
                            unnest(cols = c(.extracts)) %>% 
                            dplyr::mutate(fit = map(.extracts,
                            ~.x$fit$fit$fit),
                            oob_rmse = map_dbl(fit,
                                  ~sqrt(.x$prediction.error)),
                             roots = map(fit, 
                            ~rf_tree_roots(.))
                                   ) %>% 
                            dplyr::select(roots) %>% 
                            unnest(cols = c(roots))
                            
    
    ##Open a plotting window
    dev.new()
    
    # plot
    rf_roots(fit_rf) %>% 
                group_by(roots) %>% 
                count() %>% 
                dplyr::arrange(desc(n)) %>% 
                dplyr::filter(n > 75) %>% 
                ggplot(aes(fct_reorder(roots, n), n)) +
                geom_col() + 
                coord_flip() + 
                labs(x = "root", y = "count")

##Error message

Error: Problem with `mutate()` input `oob_rmse`.
x non-numeric argument to mathematical function
ℹ Input `oob_rmse` is `map_dbl(fit, ~sqrt(.x$prediction.error))`.
Run `rlang::last_error()` to see where the error occurred.
Called from: signal_abort(cnd)

数据框 - FID

  structure(list(Year = c(2015, 2015, 2015, 2015, 2015, 2015, 2015, 
2015, 2015, 2015, 2015, 2015, 2016, 2016, 2016, 2016, 2016, 2016, 
2016, 2016, 2016, 2016, 2016, 2016, 2017, 2017, 2017, 2017, 2017, 
2017, 2017, 2017, 2017, 2017, 2017, 2017), Month = structure(c(1L, 
2L, 3L, 4L, 5L, 6L, 7L, 8L, 9L, 10L, 11L, 12L, 1L, 2L, 3L, 4L, 
5L, 6L, 7L, 8L, 9L, 10L, 11L, 12L, 1L, 2L, 3L, 4L, 5L, 6L, 7L, 
8L, 9L, 10L, 11L, 12L), .Label = c("January", "February", "March", 
"April", "May", "June", "July", "August", "September", "October", 
"November", "December"), class = "factor"), Frequency = c(36, 
28, 39, 46, 5, 0, 0, 22, 10, 15, 8, 33, 33, 29, 31, 23, 8, 9, 
7, 40, 41, 41, 30, 30, 44, 37, 41, 42, 20, 0, 7, 27, 35, 27, 
43, 38), Days = c(31, 28, 31, 30, 6, 0, 0, 29, 15, 
29, 29, 31, 31, 29, 30, 30, 7, 0, 7, 30, 30, 31, 30, 27, 31, 
28, 30, 30, 21, 0, 7, 26, 29, 27, 29, 29)), row.names = c(NA, 
-36L), class = "data.frame")

【问题讨论】:

如果代码是可重现的将会很有帮助。例如,您没有提及所需的软件包 (library(...))。此外,在recipe() 中,您提到了一个变量Frequency_Blue,它不属于您的数据集。 对不起,塞巴斯蒂安,那里有一个复制和粘贴错字。已更正! 【参考方案1】:

如果您查看包含所有模型的 tibble,它不会正确提取错误:

fit_rf$.extracts
[[1]]
# A tibble: 1 x 1
  .extracts
  <list>   
1 <ranger> 

它嵌入在一个列表或列表中,但没有名称:

names(fit_rf$.extracts[[1]][[1]])
NULL

因此这部分会失败:

map(fit_rf$.extracts,~.x$fit$fit$fit)

如果你看一下第一个 unnest 之后的结构,这已经很合适了:

fit_rf %>% dplyr::select(.extracts) %>% unnest(cols = c(.extracts)) 
# A tibble: 3 x 1
  .extracts
  <list>   
1 <ranger> 
2 <ranger> 
3 <ranger> 

所以我们可以这样做:

rf_roots <- function(x)
                       x %>% 
                       select(.extracts) %>% 
                       unnest(cols = c(.extracts)) %>% 
                       mutate(oob_rmse = map_dbl(.extracts,
                                  ~sqrt(.x$prediction.error)),
                              roots = map(.extracts, 
                                  ~rf_tree_roots(.))
                               ) %>% 
                        dplyr::select(roots) %>% 
                        unnest(cols = c(roots))
                        

现在可以使用了:

rf_roots(fit_rf)
# A tibble: 3,000 x 1
   roots          
   <chr>          
 1 Month_August   
 2 Year           
 3 Month_July     
 4 Month_September
 5 Month_December 
 6 Month_March    
 7 Month_July     
 8 Month_September
 9 Month_December 
10 Days        

附加组件:如果目标是获取每个模型中每棵树的根变量,可以简单地这样做:

root_vars = unnest(fit_rf,.extracts) %>% 
pull(.extracts) %>% 
map(rf_tree_roots)

或者在基础 R 中:

lapply(fit_rf$.extracts,function(i)rf_tree_roots(i[[1]][[1]]))

您可以轻松地将其取消列出以制作条形图。

【讨论】:

谢谢笨狼!我很高兴得到您的帮助,对此深表感谢。我想通过精确定位最适合的模型来进行模型比较过程。在本教程(上面的链接)中,作者使用 rsme 值来确定具有最低 rsme 值的最佳拟合模型。我还在学习中 出于兴趣,你会用什么方法? 嗨@AliceHobbs,对不起,我应该进一步解释,在rf_roots 中,rmse 和根都被提取,但只返回根,使前者变得多余 看完帖子,rmse在bcullen.rbind.io/post/…使用但与此功能无关 我只是想更多的想法可以进入代码,或者数据是如何存储的。可以避免很多 map、select、unnest。很可能某些部分将来会引发错误,dplyr 和 tidyr 仍在不断发展。一旦熟悉了它在做什么,重写一些代码是有意义的

以上是关于Tidymodels 包:使用 ggplot() 可视化随机森林模型以显示最重要的预测变量的主要内容,如果未能解决你的问题,请参考以下文章

将 tidymodels 拟合模型应用于新的未标记数据

使用 step_naomit 进行预测并使用 tidymodels 保留 ID

如何将经过训练和测试的随机森林模型应用于 tidymodels 中的新数据集?

Tidymodels:在 R 中进行 10 倍交叉验证后,从 TIbble 中取消最佳拟合模型的 RMSE 和 RSQ 值

R语言ggplot2可视化:使用patchwork包绘制ggplot2可视化结果的组合图(自定义图像的嵌入关系)使用patchwork包绘制ggplot2可视化结果的组合图(自定义组合形式)

R使用pROC和ggplot2包绘制ROC曲线