Getting Started with R: Implementing Cross-Validation for Multiple Regression
Posted by 统计学小王子
Preface: this article was compiled by the editors of 小常识网 (cha138.com) and introduces how to implement cross-validation for multiple regression in R.
Introduction
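As model complexity and the amount of data grow, a model can fit its training data ever more closely. Accuracy measured on those same data (the traditional in-sample approach) then becomes overly optimistic and says little about how well the model generalizes. We therefore evaluate the model by splitting the data set, following one basic rule: the data used for evaluation are never used to fit the model. Taking multiple linear regression as an example, this article writes a function that performs cross-validation (CV) and uses it to evaluate the model.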
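To make the point concrete, the following base-R sketch fits a 20-predictor regression on 40 observations and compares its RMSE on those same 40 rows with its RMSE on 1,000 fresh rows. The sample sizes, the seed, and the helper names make_data() and rmse() are illustrative assumptions, not part of the original article.

```r
# Sketch: in-sample RMSE is optimistic compared with RMSE on new data.
set.seed(1)
p <- 20          # number of predictors (illustrative)
sigma <- 0.5     # noise standard deviation, as in the article
beta <- rep(1, p)

# hypothetical helper: n rows of y = x %*% beta + noise
make_data <- function(n) {
  x <- matrix(rnorm(n * p), n, p)
  colnames(x) <- paste0("X", 1:p)
  data.frame(y = drop(x %*% beta) + rnorm(n, 0, sigma), x)
}

# hypothetical helper: root mean squared error
rmse <- function(obs, pred) sqrt(mean((obs - pred)^2))

train <- make_data(40)     # 21 coefficients estimated from only 40 rows
test  <- make_data(1000)   # independent data never used for fitting

fit <- lm(y ~ ., data = train)
train_rmse <- rmse(train$y, fitted(fit))
test_rmse  <- rmse(test$y, predict(fit, newdata = test))
c(train = train_rmse, test = test_rmse)
```

The in-sample RMSE typically lands below the noise level of 0.5 while the held-out RMSE lands above it; that gap is what evaluating on data not used for fitting is meant to expose.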
The article uses the following three packages; load them before running the code:
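# Load packages
library(ggplot2)     # plotting the simulation results
library(data.table)  # collecting and summarizing simulation results
library(MASS)        # mvrnorm() for multivariate normal draws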
1. Writing the main functions
1.1 Generating random data
– n: number of observations to generate
– p: number of predictors (dimension of x)
– Beta: true regression coefficients
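The residual (noise) standard deviation of the model is fixed at 0.5. Because the predictors are independent standard normals and Beta = rep(1, p) by default, the signal variance is p, so the noise is small relative to the signal.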
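getData <- function(n = 200, p = 3, Beta = rep(1, p)) {
  # n x p matrix of independent N(0, 1) predictors
  x <- MASS::mvrnorm(n, rep(0, p), diag(rep(1, p)))
  # colnames() labels matrix columns; names() would not
  colnames(x) <- paste0('X', 1:p)
  # linear signal plus Gaussian noise with sd 0.5
  y <- x %*% Beta + rnorm(n, 0, 0.5)
  data.frame(y = y, x)
}
head(getData())   # first six rows; values differ from run to run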
y X1 X2 X3
1 -1.3952665 -0.7776920 -1.2173153 1.0370153
2 -0.3843078 0.1367314 0.7210690 -0.9921161
3 2.2212180 0.8984214 0.1888832 0.8709175
4 1.8005513 -0.2479999 1.1831627 0.9138255
5 -1.7861434 -0.8760386 -1.9289443 1.6050439
6 -3.4022036 -2.2106594 -1.3107755 -0.4465723
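A quick way to check that data generated this way follow the intended model is to fit lm() to one large draw and compare the estimates with the truth. The sketch below draws the predictors with rnorm() instead of MASS::mvrnorm(); with an identity covariance matrix the two have the same distribution, but that substitution, the sample size, and the seed are assumptions of this sketch.

```r
# Sketch: recover the true coefficients and noise sd from one large simulated sample.
set.seed(42)
n <- 10000; p <- 3; Beta <- rep(1, p)
x <- matrix(rnorm(n * p), n, p)   # independent N(0, 1) columns, as diag(rep(1, p)) implies
colnames(x) <- paste0("X", 1:p)
y <- drop(x %*% Beta) + rnorm(n, 0, 0.5)
fit <- lm(y ~ ., data = data.frame(y = y, x))
round(coef(fit), 2)   # intercept near 0, slopes near 1
sigma(fit)            # residual standard error, near 0.5
```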
1.2 Writing CV.lm
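– data: the data set to model (a response column y plus the predictors)
– k: number of folds (k = n gives leave-one-out cross-validation)
– p: share of the observations assigned to each fold; by default all folds are approximately the same size
– method: 'CV' or 'MCMC', default 'CV'. With 'MCMC' the function performs a single random train/test split: the first fold (share p[1]) is the test set and the remaining observations form the training set. Despite the name, this is hold-out validation (one round of Monte Carlo cross-validation), not Markov chain Monte Carlo.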
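For ordinary least squares, leave-one-out CV (k = n) does not actually require n refits: the held-out residual of observation i equals e_i / (1 - h_ii), where e_i is the ordinary residual and h_ii the leverage from a single fit on all the data. The base-R sketch below checks this identity against an explicit refit loop, so it can serve as a cross-check for CV.lm(data, k = nrow(data)); the data, seed, and variable names are illustrative.

```r
# Sketch: LOOCV RMSE for lm() computed two ways.
set.seed(7)
n <- 60; p <- 5
x <- matrix(rnorm(n * p), n, p)
colnames(x) <- paste0("X", 1:p)
dat <- data.frame(y = drop(x %*% rep(1, p)) + rnorm(n, 0, 0.5), x)

# 1) Brute force: refit n times, each time predicting the left-out row
loo_pred <- vapply(seq_len(n), function(i) {
  fit_i <- lm(y ~ ., data = dat[-i, ])
  unname(predict(fit_i, newdata = dat[i, ]))
}, numeric(1))
rmse_brute <- sqrt(mean((dat$y - loo_pred)^2))

# 2) Shortcut: one fit, deleted residuals e_i / (1 - h_ii)
fit <- lm(y ~ ., data = dat)
press_resid <- residuals(fit) / (1 - hatvalues(fit))
rmse_short <- sqrt(mean(press_resid^2))

all.equal(rmse_brute, rmse_short)   # equal up to floating-point error
```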
CV.lm <- function(data = Data, k = 5, p = rep(1, k)/k, method = 'CV') {
  n <- nrow(data)
  # fold sizes; spread the rounding remainder e over the first |e| folds
  # (sign(e) also covers the case where rounding overshoots n)
  ngroup <- round(n*p)
  e <- n - sum(ngroup)
  if(e != 0) {
    idx <- seq_len(abs(e))
    ngroup[idx] <- ngroup[idx] + sign(e)
  }
  # first and last position of each fold in the shuffled index vector (k >= 2)
  start <- c(1, cumsum(ngroup)[1:(k-1)]+1)
  end <- cumsum(ngroup)
  ind <- sample(1:n)   # random permutation of the rows
  if(method != 'CV') {
    # single hold-out split: fold 1 is the test set
    test <- ind[start[1]:end[1]]
    fit <- lm(y ~ ., data = data[-test, ])
    yhat <- predict(fit, newdata = data[test, ])
    erro <- yhat - data$y[test]
    # average over the test observations only, not over nrow(data)
    RMSE <- sqrt((t(erro) %*% erro)/length(erro))
    return(list(yhat = yhat, RMSE = RMSE))
  }
  yhats <- c()
  for(i in 1:k) {
    test <- ind[start[i]:end[i]]
    fit <- lm(y ~ ., data = data[-test, ])                   # fit without fold i
    yhats <- c(yhats, predict(fit, newdata = data[test, ]))  # predict fold i
  }
  # yhats follows the order of ind, so compare it with data$y[ind]
  erro <- yhats - data$y[ind]
  RMSE <- sqrt((t(erro) %*% erro)/nrow(data))
  out <- list(yhat = yhats, RMSE = RMSE)
  return(out)
}
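The fold bookkeeping at the top of CV.lm (the vectors ngroup, start and end) is the least obvious part of the function. The standalone sketch below repeats those steps for n = 10 observations and k = 3 equal folds so the intermediate vectors can be inspected. The values of n, k, and the seed are illustrative; distributing the remainder with sign(e) is an assumption of this sketch that also handles the case where rounding overshoots n, which a loop over 1:e would not.

```r
# Sketch of the fold bookkeeping in CV.lm for n = 10, k = 3 (illustrative values).
n <- 10; k <- 3
p <- rep(1, k) / k                       # equal fold proportions
ngroup <- round(n * p)                   # 3 3 3: rounding loses one observation
e <- n - sum(ngroup)                     # remainder to redistribute: 1
ngroup[seq_len(abs(e))] <- ngroup[seq_len(abs(e))] + sign(e)  # 4 3 3
start <- c(1, cumsum(ngroup)[1:(k - 1)] + 1)   # 1 5 8
end <- cumsum(ngroup)                           # 4 7 10
set.seed(3)
ind <- sample(1:n)                       # random permutation of row indices
folds <- lapply(1:k, function(i) ind[start[i]:end[i]])
folds  # each row index appears in exactly one fold
```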
1.3 Calling CV.lm
> Data <- getData(n = 100, p = 20)
> CV.lm(Data, k = 5, method = 'CV')
$yhat
81 67 20 74 71 100 41 51 83 44 1 5 26
2.18654037 -0.01443100 9.87465172 -6.88415433 -4.84993135 -5.46583888 -6.24346588 -3.29852477 -0.91897172 5.83005552 -1.74781478 -9.20326282 -4.43044632
46 93 73 12 86 90 2 50 39 43 30 76 21
-2.52187250 -2.56797877 1.19273371 -8.09899265 0.12887602 4.43536269 -2.61761784 -0.26978052 -2.15083278 -7.51510079 2.79867447 -2.40063565 -0.19497446
79 49 25 45 16 85 89 3 58 65 32 61 84
-5.76782017 4.56694383 -0.21526687 3.25852907 -0.52974764 -3.83607700 1.62073269 3.08459944 0.02073593 2.62093830 -1.40076722 -5.52651415 -7.27654391
36 77 18 14 28 11 13 78 35 98 95 69 54
1.90159961 -1.33820465 -1.00851747 2.39342585 -1.44773031 -1.52923336 -7.20941753 -4.58400811 -0.33950637 -6.81580122 -3.43827376 1.83511295 0.49283565
82 42 59 10 19 15 4 33 91 62 60 22 38
-0.01923878 7.36222613 -0.48926252 -6.74970251 -9.11778698 1.94610032 1.74760592 -4.24623324 -1.76799507 1.18516389 1.79362303 -9.22521042 -8.64803652
34 87 68 72 37 6 63 80 31 53 88 70 56
1.96267748 -4.36196234 -2.64362333 -5.44288180 -4.34245648 0.25427637 7.74405120 3.41536736 9.13814499 -11.86642616 -7.39859153 8.01582381 -6.89951848
55 96 7 57 17 40 48 27 64 92 23 75 9
10.04581990 8.31775866 6.85081589 -10.29939349 0.75059299 -4.95233326 -1.36855339 2.56054428 6.77722209 1.09484657 -4.81534017 0.67397664 -7.86602131
24 8 94 29 99 52 47 66 97
-3.41131443 2.90677676 -7.04131687 -1.17112236 -2.90742471 -6.22888626 4.69440408 0.72042252 1.89888377
$RMSE
[,1]
[1,] 0.5477361
> CV.lm(Data, k = 100, method = 'CV')
$yhat
65 24 72 36 18 68 63 3 8 61 20 48 11
2.52588053 -3.47060450 -5.32799035 1.84975569 -1.12687832 -2.64406253 7.71981902 3.22563182 2.97051773 -5.33389902 10.18055762 -1.51795688 -1.69575435
89 67 78 58 79 5 51 9 34 75 83 86 37
1.69704665 -0.01853040 -4.66147804 0.09195376 -5.64092202 -9.04541722 -3.13346224 -7.82568027 2.10499678 0.71148817 -0.90057462 0.01139815 -4.03867105
17 45 35 32 50 73 77 46 4 74 26 31 40
0.88123507 3.14912166 -0.30819995 -1.35422218 -0.47778272 1.06719657 -1.37560765 -2.22359621 1.74974015 -7.09349240 -4.54148902 9.00315834 -4.93746637
1 21 47 42 28 62 91 52 57 12 22 70 6
-1.71513134 -0.08771302 4.54552617 7.15571019 -1.39776243 1.36692534 -1.88909668 -6.18803453 -10.21234077 -8.22405093 -9.23876931 8.14555067 0.34291775
10 30 15 80 99 85 23 16 53 71 38 44 54
-6.73316821 2.80866625 1.88918155 3.26064640 -3.06896221 -3.72892140 -4.84853389 -0.44968635 -11.53890858 -4.83509140 -9.12473157 5.84074956 0.48655993
93 97 39 82 2 84 92 76 60 41 69 88 90
-2.49499509 1.78350944 -2.07119866 -0.09541769 -2.61132764 -7.13839842 1.13286076 -2.22649789 1.70945033 -6.17799578 1.72007248 -7.30759854 4.40874681
95 14 59 100 33 66 43 98 25 55 7 19 81
-3.57752862 2.22975392 -0.63589766 -5.44952858 -4.26985257 0.72807131 -7.38922570 -6.83987480 -0.11549540 10.21944053 6.74082071 -9.27371592 1.87665457
87 64 96 13 29 94 56 27 49
-4.39576478 6.90841001 8.38828200 -7.20398759 -0.98490498 -7.13502356 -7.14203152 2.58218323 4.65594828
$RMSE
[,1]
[1,] 0.5246706
> CV.lm(Data, k = 2, p = c(0.3, 0.7), method = 'MCMC')
$yhat
18 26 4 65 73 23 93 61 49 52 32 96 35 24 70
-1.18704798 -2.22763527 -0.70294619 1.54361154 -1.90441413 -2.69264309 1.34857788 -2.01483135 0.72469515 -0.55287414 1.13532314 0.13860175 0.34771288 -0.32094229 2.42213708
40 55 13 74 57 28 86 54 9 77 51 53 90 5 21
-0.55299565 1.10809922 -0.12173624 1.73395684 -1.17323646 -4.69171818 1.73944592 -1.55977312 -2.86850571 0.50488980 0.57255653 2.22518454 -0.02745386 0.45510601 -3.04409607
46 94 81 62 72 63 25 97 29 8 89 7 80 31 12
3.65177053 -1.88996439 -1.49158575 0.45341147 1.00988180 -1.37892843 -1.02138495 6.12285936 0.14682676 1.89112537 -0.22454335 -1.08578317 0.05420828 -0.35956107 1.45192502
88 48 100 39 41 45 58 20 71 79 91 69 56 76 3
-0.32186855 0.77325767 3.04858813 -0.56988476 -0.58851572 0.59396908 1.12596000 -1.54215044 0.53622862 -1.68363330 -1.31871060 1.33894845 0.31806274 -0.06714099 3.25564750
30 68 43 92 60 78 42 95 75 10 2 19 44 82 66
3.63180068 -1.56899627 0.25288263 -0.90868528 0.01835371 -3.09540446 -1.00224356 1.30433837 1.83360141 0.06254247 -2.45426180 -0.40325555 -0.40294029 -0.90668144 -0.29552699
98 11 17 33 64 27 50 15 47 34 14 85 84 99 1
2.42805749 -1.26225061 -1.80221627 -2.32232043 -1.12199864 0.05185008 -0.44916037 0.48802865 -1.11261560 -0.50571384 -0.34688728 0.53972095 -0.59034839 -1.65673992 0.98014780
38 37 36 67 22 87 59 83 6 16
1.07202486 0.02004368 1.80763711 -1.26913808 0.67916018 -2.00694380 1.41128510 1.15813164 0.83946964 0.83548532
$RMSE
[,1]
[1,] 0.2736122
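The calls above show that the function works. A hold-out RMSE should average the squared errors over the test observations only; dividing by nrow(data) instead shrinks it by a factor of sqrt(n_test/n), about sqrt(0.3) ≈ 0.55 for a 30% test share. That would explain why the 'MCMC' value printed above (0.27 ≈ 0.5 × 0.55) is much smaller than the CV values. Because every call splits the data at random, results vary from run to run, so the next section repeats the CV 100 times.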
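Since a single random split gives a noisy estimate, a common remedy is to repeat the hold-out split many times and average the results, usually called Monte Carlo (repeated random sub-sampling) cross-validation. The sketch below is a self-contained base-R version; the function name mc_cv(), the 30% test share, B = 200 repetitions, and the simulated data are illustrative choices, not part of the original code.

```r
# Sketch: Monte Carlo CV = B independent random hold-out splits, RMSE averaged.
mc_cv <- function(data, test_prop = 0.3, B = 200) {
  n <- nrow(data)
  n_test <- round(n * test_prop)
  rmses <- replicate(B, {
    test_idx <- sample(n, n_test)                 # fresh random split each time
    fit <- lm(y ~ ., data = data[-test_idx, ])
    err <- data$y[test_idx] - predict(fit, newdata = data[test_idx, ])
    sqrt(mean(err^2))                             # averaged over test rows only
  })
  c(mean = mean(rmses), sd = sd(rmses))
}

set.seed(11)
n <- 100; p <- 20
x <- matrix(rnorm(n * p), n, p)
colnames(x) <- paste0("X", 1:p)
dat <- data.frame(y = drop(x %*% rep(1, p)) + rnorm(n, 0, 0.5), x)
mc_cv(dat)
```

Averaging over many splits reduces the split-to-split variability, and the sd element shows how large that variability is for a single split.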
2. Numerical simulations
2.1 CV simulation
N <- c(100, 300, 500, 1000)   # sample sizes
K <- c(5, 10, 20, 50)         # numbers of folds
RMSEs <- data.table()         # empty table, filled row by row
for(j in 1:100) {
  for(n in N) {
    for(k in K) {
      Data <- getData(n = n, p = 20)
      RMSE <- CV.lm(Data, k = k, method = 'CV')$RMSE
      # RMSE is a 1 x 1 matrix, so the new column is named RMSE.V1
      RMSEs <- rbind(RMSEs, data.table(n = n, k = k, RMSE = RMSE))
      cat('run', j, ': n =', n, ', k =', k, '\n')
    }
  }
}
> RMSEs[, .(RMSE_mean = mean(RMSE.V1), RMSE_sd = sd(RMSE.V1)), by = c('n', 'k')]
n k RMSE_mean RMSE_sd
1: 100 5 0.5796997 0.055523027
2: 100 10 0.5579595 0.046627906
3: 100 20 0.5698858 0.047016860
4: 100 50 0.5613656 0.045983065
5: 300 5 0.5207754 0.023325919
6: 300 10 0.5188533 0.020100251
7: 300 20 0.5191463 0.022133704
8: 300 50 0.5187397 0.020079312
9: 500 5 0.5101782 0.016237246
10: 500 10 0.5100151 0.016728890
11: 500 20 0.5114958 0.016420630
12: 500 50 0.5108451 0.017669802
13: 1000 5 0.5050968 0.012085097
14: 1000 10 0.5032225 0.012533446
15: 1000 20 0.5079430 0.009645647
16: 1000 50 0.5044087 0.009994131
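The table shows the CV RMSE falling toward the noise level 0.5 as n grows. One way to see why: under the simulation's assumptions (independent standard-normal predictors, Gaussian noise with sd sigma = 0.5, OLS with an intercept), the expected squared prediction error of a model fit on m observations with p predictors is sigma^2 (1 + 1/m)(m - 2)/(m - p - 2), a standard result for random Gaussian designs that requires m > p + 2. k-fold CV fits on about m = n(k - 1)/k rows, so the excess over sigma is largest when n is small relative to p = 20; for n = 100 and k = 5 (m = 80) the formula gives about 0.583, close to the simulated mean of 0.580. The helper below (its name expected_rmse() is hypothetical) evaluates the square root of that expression. Treat it as a rough benchmark only, since the CV figure is the square root of an average over one data set rather than the square root of an expectation.

```r
# Sketch: theoretical RMSE benchmark for the CV simulation.
# Assumes x ~ N(0, I_p), Gaussian noise with sd sigma, OLS with intercept,
# and a training size m > p + 2; expected_rmse() is a hypothetical helper.
expected_rmse <- function(m, p = 20, sigma = 0.5) {
  stopifnot(all(m > p + 2))
  sigma * sqrt((1 + 1 / m) * (m - 2) / (m - p - 2))
}

grid <- expand.grid(n = c(100, 300, 500, 1000), k = c(5, 10, 20, 50))
grid$m <- grid$n * (grid$k - 1) / grid$k   # average training-set size in k-fold CV
grid$rmse_theory <- round(expected_rmse(grid$m), 4)
grid[order(grid$n, grid$k), ]
```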
# Boxplots of the CV RMSE for each n, one panel per number of folds k
p1 <- ggplot(data = RMSEs)
p1 + geom_boxplot(aes(x = factor(n), y = RMSE.V1, col = factor(n))) +
  facet_wrap(~ k) +
  labs(x = 'n', y = 'CV RMSE', col = 'n')
2.2 MCMC (single hold-out split) simulation
P <- c(0.1, 0.2, 0.3, 0.4, 0.5, 0.5