在 Matlab 中计算广义线性模型的交叉验证

Posted

技术标签:

【中文标题】在 Matlab 中计算广义线性模型的交叉验证【英文标题】:Calculate cross validation for Generalized Linear Model in Matlab 【发布时间】:2014-07-17 05:42:01 【问题描述】:

我正在使用广义线性模型进行回归。使用crossVal 函数让我措手不及。到目前为止我的实现;

x = 'Some dataset, containing the input and the output'

X = x(:,1:7);
Y = x(:,8);

cvpart = cvpartition(Y,'holdout',0.3);
Xtrain = X(training(cvpart),:);
Ytrain = Y(training(cvpart),:);
Xtest = X(test(cvpart),:);
Ytest = Y(test(cvpart),:);

mdl = GeneralizedLinearModel.fit(Xtrain,Ytrain,'linear','distr','poisson');

Ypred  = predict(mdl,Xtest);
res = (Ypred - Ytest);
RMSE_test = sqrt(mean(res.^2));

下面的代码用于计算从link 获得的多重回归的交叉验证。我想要广义线性模型类似的东西。

c = cvpartition(Y,'k',10);
regf=@(Xtrain,Ytrain,Xtest)(Xtest*regress(Ytrain,Xtrain));
cvMse = crossval('mse',X,Y,'predfun',regf)

【问题讨论】:

我想使用 Matlab 统计工具箱中的函数计算 GLM 的交叉验证误差。现在,有什么我遗漏的,或者 Matlab 根本没有内置这样的功能。 【参考方案1】:

您可以手动执行交叉验证过程(为每个折叠训练模型、预测结果、计算错误,然后报告所有折叠的平均值),也可以使用包含整个过程的 CROSSVAL 函数一次通话。

举个例子,我将首先加载和准备一个数据集(cars dataset 的一个子集,随统计工具箱一起提供):

% load regression dataset
load carsmall
X = [Acceleration Cylinders Displacement Horsepower Weight];
Y = MPG;

% remove instances with missing values
missIdx = isnan(Y) | any(isnan(X),2);
X(missIdx,:) = [];
Y(missIdx) = [];

clearvars -except X Y

选项 1

这里我们将使用k-fold cross-validation 使用cvpartition(非分层)手动对数据进行分区。对于每一折,我们使用训练数据训练一个GLM 模型,然后使用该模型来预测测试数据的输出。接下来,我们计算并存储此折叠的回归 mean squared error。最后,我们报告所有分区的平均 RMSE。

% partition data into 10 folds
K = 10;
cv = cvpartition(numel(Y), 'kfold',K);

mse = zeros(K,1);
for k=1:K
    % training/testing indices for this fold
    trainIdx = cv.training(k);
    testIdx = cv.test(k);

    % train GLM model
    mdl = GeneralizedLinearModel.fit(X(trainIdx,:), Y(trainIdx), ...
        'linear', 'Distribution','poisson');

    % predict regression output
    Y_hat = predict(mdl, X(testIdx,:));

    % compute mean squared error
    mse(k) = mean((Y(testIdx) - Y_hat).^2);
end

% average RMSE across k-folds
avrg_rmse = mean(sqrt(mse))

选项 2

在这里,我们可以简单地使用适当的函数句柄调用 CROSSVAL,该函数句柄在给定一组训练/测试实例的情况下计算回归输出。参看文档页面了解参数。

% prediction function given training/testing instances
fcn = @(Xtr, Ytr, Xte) predict(...
    GeneralizedLinearModel.fit(Xtr,Ytr,'linear','distr','poisson'), ...
    Xte);

% perform cross-validation, and return average MSE across folds
mse = crossval('mse', X, Y, 'Predfun',fcn, 'kfold',10);

% compute root mean squared error
avrg_rmse = sqrt(mse)

与以前相比,您应该得到类似的结果(当然略有不同,因为交叉验证中涉及的随机性)。

【讨论】:

让我明确一点,crossval 函数如何知道它的数据的哪一部分是 Xtr,也就是 Ytr。我知道 fcn 是一个匿名函数,但是必须有人告诉它数据的哪一部分是哪一部分。我已经浏览了文档,但匿名函数的这种数据分离仍然让我难以理解。 @motiur:crossval 使用相同的技术在内部对数据进行分区,然后对于每个折叠,它调用函数句柄传递适当的训练/测试数据集。事实上,如果你愿意,你可以告诉crossval 使用你自己的cvpartition 对象..

以上是关于在 Matlab 中计算广义线性模型的交叉验证的主要内容,如果未能解决你的问题,请参考以下文章

R 与 scikit-learn 中用于线性回归 R2 的交叉验证

谁会用spss软件里的广义线性模型进行回归分析,文章急着要投出去得用到这个模型,希望会的棒棒忙,谢谢了

在插入符号交叉验证期间计算模型校准?

R语言广义线性模型函数GLMglm函数构建逻辑回归模型(Logistic regression)去除初步验证不具有显著性的特征再次构建逻辑回归模型简化模型(reduced model)

spss广义线性混合效应模型中的随机效应怎么交互

机器学习-广义线性模型