无法在 caret 包中为 extraTrees 模型指定概率函数
Posted
技术标签:
【中文标题】无法在 caret 包中为 extraTrees 模型指定概率函数【英文标题】:Cannot specify probability function for extraTrees model in caret package 【发布时间】:2015-08-16 17:18:21 【问题描述】:大家,
最近,我一直在 caret 包中使用 extraTrees 模型。但是,我注意到通过使用以下脚本将 extraTrees 模型的概率函数设置为 NULL:
extratrees_para <- getModelInfo('extraTrees', regex = F)[[1]]
extratrees_para$prob
我注意到在extraTress的原始包中,它可以用于生成分类问题的概率预测。所以我想为extratrees_para 指定prob 函数。
extratrees_para$prob <- function(modelFit, newdata, submodels = NULL)
as.data.frame(predict(modelFit, newdata, probability = TRUE))
extratrees_para$type <- 'Classification'
然后我构造训练函数来构建模型
extratreesGrid <- expand.grid(.mtry=1:2,
.numRandomCuts=1)
modelfit_extratrees <- train(outcome~., data=training_scaled_sel,
method = extratrees_para,
metric = "ROC",
trControl = trainControl(method = 'repeatedcv',
repeats=1,
classProb = T,
summaryFunction = twoClassSummary),
ntree = 3000,
tuneGrid = extratreesGrid)
但是,我不断收到这个信息量不大的错误消息
"train.default(x, y, weights = w, ...) 中的错误: 最终调整参数无法确定 另外:警告信息: 1: 在nominalTrainWorkflow(x = x, y = y, wts = weights, info = trainInfo, : 重新抽样的绩效指标中存在缺失值。 2:在 train.default(x, y, weights = w, ...) 中: 在汇总结果中发现缺失值"
以下是我的会话信息。如果有人可以帮助我,我将不胜感激。谢谢!
sessioninfo()
R version 3.1.2 (2014-10-31)
Platform: x86_64-w64-mingw32/x64 (64-bit)
locale:
[1] LC_COLLATE=English_United States.1252
[2] LC_CTYPE=English_United States.1252
[3] LC_MONETARY=English_United States.1252
[4] LC_NUMERIC=C
[5] LC_TIME=English_United States.1252
attached base packages:
[1] grid stats graphics grDevices utils datasets methods
[8] base
other attached packages:
[1] DMwR_0.4.1 BiocInstaller_1.16.5 caret_6.0-41
[4] ggplot2_1.0.0 lattice_0.20-29 extraTrees_1.0.5
[7] rJava_0.9-6
loaded via a namespace (and not attached):
[1] abind_1.4-3 bitops_1.0-6 BradleyTerry2_1.0-5
[4] brglm_0.5-9 car_2.0-24 caTools_1.17.1
[7] class_7.3-11 codetools_0.2-9 colorspace_1.2-4
[10] compiler_3.1.2 digest_0.6.8 e1071_1.6-4
[13] foreach_1.4.2 gdata_2.16.1 gplots_2.17.0
[16] gtable_0.1.2 gtools_3.4.1 iterators_1.0.7
[19] KernSmooth_2.23-13 lme4_1.1-7 MASS_7.3-35
[22] Matrix_1.1-4 mgcv_1.8-3 minqa_1.2.4
[25] munsell_0.4.2 nlme_3.1-118 nloptr_1.0.4
[28] nnet_7.3-8 parallel_3.1.2 pbkrtest_0.4-2
[31] plyr_1.8.1 pROC_1.8 proto_0.3-10
[34] quantmod_0.4-4 quantreg_5.11 Rcpp_0.11.4
[37] reshape2_1.4.1 ROCR_1.0-7 rpart_4.1-8
[40] scales_0.2.4 SparseM_1.6 splines_3.1.2
[43] stringr_0.6.2 tools_3.1.2 TTR_0.22-0
[46] xts_0.9-7 zoo_1.7-12
【问题讨论】:
【参考方案1】:当我第一次添加模型时,我认为它不会生成类概率。我不确定为什么您的版本不起作用,但这是我要添加到包中的内容:
modelInfo <- list(label = "Random Forest by Randomization",
library = c("extraTrees"),
loop = NULL,
type = c('Regression', 'Classification'),
parameters = data.frame(parameter = c('mtry', 'numRandomCuts'),
class = c('numeric', 'numeric'),
label = c('# Randomly Selected Predictors', '# Random Cuts')),
grid = function(x, y, len = NULL)
expand.grid(mtry = var_seq(p = ncol(x),
classification = is.factor(y),
len = len),
numRandomCuts = 1:len)
,
fit = function(x, y, wts, param, lev, last, classProbs, ...)
extraTrees(x, y, mtry = param$mtry, numRandomCuts = param$numRandomCuts, ...),
predict = function(modelFit, newdata, submodels = NULL)
predict(modelFit, newdata),
prob = function(modelFit, newdata, submodels = NULL)
predict(modelFit, newdata, probability = TRUE),
levels = function(x) x$obsLevels,
tags = c("Random Forest", "Ensemble Model", "Bagging", "Implicit Feature Selection"),
sort = function(x) x[order(x[,1]),])
【讨论】:
非常感谢软件包开发人员如此迅速地回复我!我想我的模型设计和你一样。但我仍然无法让它工作(我今天早上再试一次)。 另一个问题。我总是想知道是否能找到关于你们包装设计的描述。我知道每个模型都有自己定义的模块,例如“label”、“library”、“fit”、“predict”等。但是我怎样才能更好地理解插入符号的工作原理呢?就像在我构建模型并调用训练函数之后一样。哪些函数会被顺序调用?据我了解,在调用 train 函数后,1. 预处理步骤 2. 将按指定生成随机重采样 3.将使用每个模型的拟合函数为循环变量中的模型训练模型 4. 将计算指定的指标并确定最终模型参数 5. 将使用最终模型构建最终模型。 原来我有旧版本的 caret 和 extraTrees。我更新了这两个包,现在一切正常。以上是关于无法在 caret 包中为 extraTrees 模型指定概率函数的主要内容,如果未能解决你的问题,请参考以下文章
在 R 代码中使用“caret”包中的 preProcess 的目的是啥?
“xgboost”官方包与 R 中“caret”包中的 xgboost 的不同结果
R语言:利用caret包中的dummyVars函数进行虚拟变量处理
R语言使用caret包中的createFolds函数对机器学习数据集进行交叉验证抽样返回的样本列表长度为k个
R语言使用caret包中的createResample函数进行机器学习数据集采样数据集有放回的采样(bootstrapping)