在使用插入符号的 train() 使用公式训练的 randomForest 对象上使用 predict() 时出错
Posted
技术标签:
【中文标题】在使用插入符号的 train() 使用公式训练的 randomForest 对象上使用 predict() 时出错【英文标题】:Error when using predict() on a randomForest object trained with caret's train() using formula 【发布时间】:2015-07-17 19:53:51 【问题描述】:在 64 位 Linux 机器上使用带有插入符号 6.0-41 和 randomForest 4.6-10 的 R 3.2.0。
尝试对使用公式通过 caret
包中的 train()
函数训练的 randomForest
对象使用 predict()
方法时,该函数返回错误。
通过randomForest()
和/或使用x=
和y=
而不是公式进行训练时,一切运行顺利。
这是一个工作示例:
library(randomForest)
library(caret)
data(imports85)
imp85 <- imports85[, c("stroke", "price", "fuelType", "numOfDoors")]
imp85 <- imp85[complete.cases(imp85), ]
imp85[] <- lapply(imp85, function(x) if (is.factor(x)) x[,drop=TRUE] else x) ## Drop empty levels for factors.
modRf1 <- randomForest(numOfDoors~., data=imp85)
caretRf <- train( numOfDoors~., data=imp85, method = "rf" )
modRf2 <- caretRf$finalModel
modRf3 <- randomForest(x=imp85[,c("stroke", "price", "fuelType")], y=imp85[, "numOfDoors"])
caretRf <- train(x=imp85[,c("stroke", "price", "fuelType")], y=imp85[, "numOfDoors"], method = "rf")
modRf4 <- caretRf$finalModel
p1 <- predict(modRf1, newdata=imp85)
p2 <- predict(modRf2, newdata=imp85)
p3 <- predict(modRf3, newdata=imp85)
p4 <- predict(modRf4, newdata=imp85)
最后4行中,只有第二行p2 <- predict(modRf2, newdata=imp85)
返回如下错误:
Error in predict.randomForest(modRf2, newdata = imp85) :
variables in the training data missing in newdata
看来这个错误的原因是predict.randomForest
方法使用rownames(object$importance)
来确定用于训练随机森林object
的变量的名称。看的时候
rownames(modRf1$importance)
rownames(modRf2$importance)
rownames(modRf3$importance)
rownames(modRf4$importance)
我们看到了:
[1] "stroke" "price" "fuelType"
[1] "stroke" "price" "fuelTypegas"
[1] "stroke" "price" "fuelType"
[1] "stroke" "price" "fuelType"
因此,不知何故,当使用带有公式的caret
train()
函数时,会更改randomForest
对象的importance
字段中的(因子)变量的名称。
插入符号train()
函数的公式和非公式版本之间真的不一致吗?还是我错过了什么?
【问题讨论】:
modRf3 <- randomForest(x=dataTrain[,c("stroke", "price", "fuelType")], y=dataTrain[, "numOfDoors"], data=imp85) Error in randomForest(x = dataTrain[, c("stroke", "price", "fuelType")], : object 'dataTrain' not found
正如所指出的,您没有在示例中定义dataTrain
,这意味着问题不是reproducible。如果我们无法运行代码并获得与您相同的结果,那么帮助您并不容易。
我的错,dataTrain
应该是imp85
,我编辑了原始问题中的代码。我还在调用中删除了选项data=imp85
,其中明确提到了x
和y
,因为它没有用处。
【参考方案1】:
首先,几乎从不使用$finalModel
对象进行预测。使用predict.train
。这是一个很好的例子。
某些函数(包括randomForest
和train
)处理虚拟变量的方式存在一些不一致。 R 中使用公式方法的大多数函数会将因子预测变量转换为虚拟变量,因为它们的模型需要数据的数字表示。例外情况是基于树和基于规则的模型(可以根据分类预测变量进行拆分)、朴素贝叶斯和其他一些模型。
所以randomForest
在您使用randomForest(y ~ ., data = dat)
时不会创建虚拟变量,但train
(和大多数其他人)将使用像train(y ~ ., data = dat)
这样的调用。
发生错误是因为fuelType
是一个因素。 train
创建的虚拟变量名称不同,因此 predict.randomForest
找不到它们。
对train
使用非公式方法会将因子预测变量传递给randomForest
,一切都会正常进行。
TL;DR
如果您想要相同的级别或使用predict.train
train
的非公式方法
最大
【讨论】:
不幸的是,我没有足够的声誉来支持您的答案,但您完美地回答了我的问题。我一直想知道所有那些允许使用公式的函数,如果函数调用的公式和非公式版本之间的数据处理方式存在差异。现在我明白了!对于$finalModel
的使用,我同意使用它通常不是一个好主意。这里我只是想比较caret
和randomForest
方法的结果。【参考方案2】:
您收到此错误的原因可能有两个。
1.训练集和测试集中的分类变量的类别不匹配。要检查这一点,您可以运行以下内容。
嗯,首先,将自变量/特征保存在一个列表中是一种很好的做法。假设该列表是“vars”。并且说,您将“数据”分为“训练”和“测试”。走吧:
for (v in vars) if (class(Data[,v]) == 'factor') print(v) # print(levels(Train[,v])) # print(levels(Test[,v])) print(all.equal(levels(Train[,v]) , levels(Test[,v])))
找到不匹配的分类变量后,您可以返回,将测试数据的类别强加到训练数据上,然后重新构建模型。在类似于上面的循环中,对于每个 nonMatchingVar,你可以这样做
levels(Test$nonMatchingVar) <- levels(Train$nonMatchingVar)
2. 一个愚蠢的。如果您不小心将因变量留在了自变量集中,您可能会遇到此错误消息。我犯了那个错误。解决方案:小心一点。
【讨论】:
【参考方案3】:另一种方法是使用model.matrix
对测试数据进行显式编码,例如
p2 <- predict(modRf2, newdata=model.matrix(~., imp85))
【讨论】:
【参考方案4】:这不是您问题的答案,但我相信它会帮助其他人,因为它帮助了我。如果您在训练数据列中使用的测试数据列中缺少任何 NA,则预测将不起作用。您需要先估算这些值。
【讨论】:
以上是关于在使用插入符号的 train() 使用公式训练的 randomForest 对象上使用 predict() 时出错的主要内容,如果未能解决你的问题,请参考以下文章
警告消息:使用 rpart 的插入符号 train() 中的“重采样性能测量中的缺失值”