在使用插入符号的 train() 使用公式训练的 randomForest 对象上使用 predict() 时出错

Posted

技术标签:

【中文标题】在使用插入符号的 train() 使用公式训练的 randomForest 对象上使用 predict() 时出错【英文标题】:Error when using predict() on a randomForest object trained with caret's train() using formula 【发布时间】:2015-07-17 19:53:51 【问题描述】:

在 64 位 Linux 机器上使用带有插入符号 6.0-41 和 randomForest 4.6-10 的 R 3.2.0。

尝试对使用公式通过 caret 包中的 train() 函数训练的 randomForest 对象使用 predict() 方法时,该函数返回错误。 通过randomForest() 和/或使用x=y= 而不是公式进行训练时,一切运行顺利。

这是一个工作示例:

library(randomForest)
library(caret)

data(imports85)
imp85     <- imports85[, c("stroke", "price", "fuelType", "numOfDoors")]
imp85     <- imp85[complete.cases(imp85), ]
imp85[]   <- lapply(imp85, function(x) if (is.factor(x)) x[,drop=TRUE] else x) ## Drop empty levels for factors.

modRf1  <- randomForest(numOfDoors~., data=imp85)
caretRf <- train( numOfDoors~., data=imp85, method = "rf" )
modRf2  <- caretRf$finalModel
modRf3  <- randomForest(x=imp85[,c("stroke", "price", "fuelType")], y=imp85[, "numOfDoors"])
caretRf <- train(x=imp85[,c("stroke", "price", "fuelType")], y=imp85[, "numOfDoors"], method = "rf")
modRf4  <- caretRf$finalModel

p1      <- predict(modRf1, newdata=imp85)
p2      <- predict(modRf2, newdata=imp85)
p3      <- predict(modRf3, newdata=imp85)
p4      <- predict(modRf4, newdata=imp85)

最后4行中,只有第二行p2 &lt;- predict(modRf2, newdata=imp85)返回如下错误:

Error in predict.randomForest(modRf2, newdata = imp85) : 
variables in the training data missing in newdata

看来这个错误的原因是predict.randomForest方法使用rownames(object$importance)来确定用于训练随机森林object的变量的名称。看的时候

rownames(modRf1$importance)
rownames(modRf2$importance)
rownames(modRf3$importance)
rownames(modRf4$importance)

我们看到了:

[1] "stroke"   "price"    "fuelType"
[1] "stroke"   "price"    "fuelTypegas"
[1] "stroke"   "price"    "fuelType"
[1] "stroke"   "price"    "fuelType"

因此,不知何故,当使用带有公式的caret train() 函数时,会更改randomForest 对象的importance 字段中的(因子)变量的名称。

插入符号train()函数的公式和非公式版本之间真的不一致吗?还是我错过了什么?

【问题讨论】:

modRf3 &lt;- randomForest(x=dataTrain[,c("stroke", "price", "fuelType")], y=dataTrain[, "numOfDoors"], data=imp85) Error in randomForest(x = dataTrain[, c("stroke", "price", "fuelType")], : object 'dataTrain' not found 正如所指出的,您没有在示例中定义dataTrain,这意味着问题不是reproducible。如果我们无法运行代码并获得与您相同的结果,那么帮助您并不容易。 我的错,dataTrain 应该是imp85,我编辑了原始问题中的代码。我还在调用中删除了选项data=imp85,其中明确提到了xy,因为它没有用处。 【参考方案1】:

首先,几乎从不使用$finalModel 对象进行预测。使用predict.train。这是一个很好的例子。

某些函数(包括randomForesttrain)处理虚拟变量的方式存在一些不一致。 R 中使用公式方法的大多数函数会将因子预测变量转换为虚拟变量,因为它们的模型需要数据的数字表示。例外情况是基于树和基于规则的模型(可以根据分类预测变量进行拆分)、朴素贝叶斯和其他一些模型。

所以randomForest 在您使用randomForest(y ~ ., data = dat)不会创建虚拟变量,但train(和大多数其他人)将使用像train(y ~ ., data = dat) 这样的调用。

发生错误是因为fuelType 是一个因素。 train 创建的虚拟变量名称不同,因此 predict.randomForest 找不到它们。

train 使用非公式方法会将因子预测变量传递给randomForest,一切都会正常进行。

TL;DR

如果您想要相同的级别使用predict.train

,请使用train 的非公式方法

最大

【讨论】:

不幸的是,我没有足够的声誉来支持您的答案,但您完美地回答了我的问题。我一直想知道所有那些允许使用公式的函数,如果函数调用的公式和非公式版本之间的数据处理方式存在差异。现在我明白了!对于$finalModel的使用,我同意使用它通常不是一个好主意。这里我只是想比较caretrandomForest方法的结果。【参考方案2】:

您收到此错误的原因可能有两个。

1.训练集和测试集中的分类变量的类别不匹配。要检查这一点,您可以运行以下内容。

嗯,首先,将自变量/特征保存在一个列表中是一种很好的做法。假设该列表是“vars”。并且说,您将“数据”分为“训练”和“测试”。走吧:

for (v in vars)
  if (class(Data[,v]) == 'factor')
    print(v)
    # print(levels(Train[,v])) 
    # print(levels(Test[,v]))
    print(all.equal(levels(Train[,v]) , levels(Test[,v])))
    

找到不匹配的分类变量后,您可以返回,将测试数据的类别强加到训练数据上,然后重新构建模型。在类似于上面的循环中,对于每个 nonMatchingVar,你可以这样做

levels(Test$nonMatchingVar) <- levels(Train$nonMatchingVar)

2. 一个愚蠢的。如果您不小心将因变量留在了自变量集中,您可能会遇到此错误消息。我犯了那个错误。解决方案:小心一点。

【讨论】:

【参考方案3】:

另一种方法是使用model.matrix 对测试数据进行显式编码,例如

p2 <- predict(modRf2, newdata=model.matrix(~., imp85))

【讨论】:

【参考方案4】:

这不是您问题的答案,但我相信它会帮助其他人,因为它帮助了我。如果您在训练数据列中使用的测试数据列中缺少任何 NA,则预测将不起作用。您需要先估算这些值。

【讨论】:

以上是关于在使用插入符号的 train() 使用公式训练的 randomForest 对象上使用 predict() 时出错的主要内容,如果未能解决你的问题,请参考以下文章

插入符号中训练数据的 ROC 曲线

使用并行训练带有插入符号的随机森林

在 R 中使用插入符号训练模型的时机

警告消息:使用 rpart 的插入符号 train() 中的“重采样性能测量中的缺失值”

在 R 中使用插入符号进行训练后,如何在 ROC 下计算 ROC 和 AUC?

插入符号训练二进制 glm 通过 doParallel 在并行集群上失败