交叉验证给出负 R2?

Posted

技术标签:

【中文标题】交叉验证给出负 R2?【英文标题】:Cross-validation gives Negative R2? 【发布时间】:2019-04-24 16:34:21 【问题描述】:

为了简单起见,我将 500 个样本从 10,000+ 行数据集中划分出来。请将 X 和 y 复制并粘贴到您的 IDE 中。

X =

array([ -8.93,  -0.17,   1.47,  -6.13,  -4.06,  -2.22,  -2.11,  -0.25,
         0.25,   0.49,   1.7 ,  -0.77,   1.07,   5.61, -11.95,  -3.8 ,
        -3.42,  -2.55,  -2.44,  -1.99,  -1.7 ,  -0.98,  -0.91,  -0.91,
        -0.25,   1.7 ,   2.88,  -6.9 ,  -4.07,  -1.35,  -0.33,   0.63,
         0.98,  -3.31,  -2.61,  -2.61,  -2.17,  -1.38,  -0.77,  -0.25,
        -0.08,  -1.2 ,  -3.1 ,  -1.07,  -0.7 ,  -0.41,  -0.33,   0.41,
         0.77,   0.77,   1.14,   2.17,  -7.92,  -3.8 ,  -2.11,  -2.06,
        -1.2 ,  -1.14,   0.  ,   0.56,   1.47,  -1.99,  -0.17,   2.44,
        -5.87,  -3.74,  -3.37,  -2.88,  -0.49,  -0.25,  -0.08,   0.33,
         0.33,   0.84,   1.64,   2.06,   2.88,  -4.58,  -1.82,  -1.2 ,
         0.25,   0.25,   0.63,   2.61,  -5.36,  -1.47,  -0.63,   0.  ,
         0.63,   1.99,   1.99, -10.44,  -2.55,   0.33,  -8.93,  -5.87,
        -5.1 ,  -2.78,  -0.25,   1.47,   1.93,   2.17,  -5.36,  -5.1 ,
        -3.48,  -2.44,  -2.06,  -2.06,  -1.82,  -1.58,  -1.58,  -0.63,
        -0.33,   0.  ,   0.17,  -3.31,  -0.25,  -5.1 ,  -3.8 ,  -2.55,
        -1.99,  -1.7 ,  -0.98,  -0.91,  -0.63,  -0.25,   0.77,   0.91,
         0.91,  -9.43,  -8.42,  -2.72,  -2.55,  -1.26,   0.7 ,   0.77,
         1.07,   1.47,   1.7 ,  -1.82,  -1.47,   0.17,   1.26,  -5.36,
        -1.52,  -1.47,  -0.17,  -3.48,  -3.31,  -2.06,  -1.47,   0.17,
         0.25,   1.7 ,   2.5 ,  -9.94,  -6.08,  -5.87,  -3.37,  -2.44,
        -2.17,  -1.87,  -0.98,  -0.7 ,  -0.49,   0.41,   1.47,   2.28,
       -14.95, -12.44,  -6.39,  -4.33,  -3.8 ,  -2.72,  -2.17,  -1.2 ,
         0.41,   0.77,   0.84,   2.51,  -1.99,  -1.7 ,  -1.47,  -1.2 ,
         0.49,   0.63,   0.84,   0.98,   1.14,   2.5 ,  -2.06,  -1.26,
        -0.33,   0.17,   4.58,  -7.41,  -5.87,   1.2 ,   1.38,   1.58,
         1.82,   1.99,  -6.39,  -2.78,  -2.67,  -1.87,  -1.58,  -1.47,
         0.84, -10.44,  -7.41,  -3.05,  -2.17,  -1.07,  -1.07,  -0.91,
         0.25,   1.82,   2.88,  -6.9 ,  -1.47,   0.33,  -8.42,  -3.8 ,
        -1.99,  -1.47,  -1.47,  -0.56,   0.17,   0.17,   0.25,   0.56,
         4.58,  -3.48,  -2.61,  -2.44,  -0.7 ,   0.63,   1.47,   1.82,
       -13.96,  -9.43,  -2.67,  -1.38,  -0.08,   0.  ,   1.82,   3.05,
        -4.58,  -3.31,  -0.98,  -0.91,  -0.7 ,   0.77,  -0.7 ,  -0.33,
         0.56,   1.58,   1.7 ,   2.61,  -4.84,  -4.84,  -4.32,  -2.88,
        -1.38,  -0.98,  -0.17,   0.17,   0.49,   2.44,   4.32,  -3.48,
        -3.05,   0.56,  -8.42,  -3.48,  -2.61,  -2.61,  -2.06,  -1.47,
        -0.98,   0.  ,   0.08,   1.38,   1.93,  -9.94,  -2.72,  -1.87,
        -1.2 ,  -1.07,   1.58,   4.58,  -6.64,  -2.78,  -0.77,  -0.7 ,
        -0.63,   0.49,   1.07,  -8.93,  -4.84,  -1.7 ,   1.76,   3.31,
       -11.95,  -3.16,  -3.05,  -1.82,  -0.49,  -0.41,   0.56,   1.58,
       -13.96,  -3.05,  -2.78,  -2.55,  -1.7 ,  -1.38,  -0.91,  -0.33,
         1.2 ,   1.32,   1.47,  -2.06,  -1.82,  -7.92,  -6.33,  -4.32,
        -3.8 ,  -1.93,  -1.52,  -0.98,  -0.49,  -0.33,   0.7 ,   1.52,
         1.76,  -8.93,  -7.41,  -2.88,  -2.61,  -2.33,  -1.99,  -1.82,
        -1.64,  -0.84,   1.07,   2.06,  -3.96,  -2.44,  -1.58,   0.  ,
        -3.31,  -2.61,  -1.58,  -0.25,   0.33,   0.56,   0.84,   1.07,
        -1.58,  -0.25,   1.35,  -1.99,  -1.7 ,  -1.47,  -1.47,  -0.84,
        -0.7 ,  -0.56,  -0.33,   0.56,   0.63,   1.32,   2.28,   2.28,
        -2.72,  -0.25,   0.41,  -6.9 ,  -4.42,  -4.32,  -1.76,  -1.2 ,
        -1.14,  -1.07,   0.56,   1.32,   1.52, -14.97,  -7.41,  -5.1 ,
        -2.61,  -1.93,  -0.98,   0.17,   0.25,   0.41,  -4.42,  -2.61,
        -0.91,  -0.84,   2.39,  -2.61,  -1.32,   0.41,  -6.9 ,  -5.61,
        -4.06,  -3.31,  -1.47,  -0.91,  -0.7 ,  -0.63,   0.33,   1.38,
         2.61,  -2.29,   3.06,   4.44, -10.94,  -4.32,  -3.42,  -2.17,
        -1.7 ,  -1.47,  -1.32,  -1.07,  -0.7 ,   0.  ,   0.77,   1.07,
        -3.31,  -2.88,  -2.61,  -1.47,  -1.38,  -0.63,  -0.49,   1.07,
         1.52,  -3.8 ,  -1.58,  -0.91,  -0.7 ,   0.77,   3.42,  -8.42,
        -2.88,  -1.76,  -1.76,  -0.63,  -0.25,   0.49,   0.63,  -6.9 ,
        -4.06,  -1.82,  -1.76,  -1.76,  -1.38,  -0.91,  -0.7 ,   0.17,
         1.38,   1.47,   1.47, -11.95,  -0.98,  -0.56, -14.97,  -9.43,
        -8.93,  -2.72,  -2.61,  -1.64,  -1.32,  -0.56,  -0.49,   0.91,
         1.2 ,   1.47,  -3.8 ,  -3.06,  -2.51,  -1.04,  -0.33,  -0.33,
        -3.31,  -3.16,  -3.05,  -2.61,  -1.47,  -1.07,   2.17,   3.1 ,
        -2.61,  -0.25,  -3.85,  -2.44])

y =

array([1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0,
       1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1,
       0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1,
       1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0,
       0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 1,
       1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1,
       1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0,
       0, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1,
       0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1,
       1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0,
       0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1,
       1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0,
       1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1,
       0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0,
       1, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1,
       1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1,
       0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0,
       0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 1, 0,
       0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0,
       0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0,
       1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 1])

初始化和训练:

from sklearn.linear_model import LogisticRegression
model = LogisticRegression()
model.fit(X, y)

交叉验证:

from sklearn.model_selection import cross_val_score
cross_val_score(model, X, y, cv=10, scoring='r2').mean()

-0.3339677563815496(负R2?)

看它是否接近模型的真实R2。我这样做了:

from sklearn.metrics import r2_score
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.20, random_state=None, shuffle=False)

r2_score(y_test, model.predict_proba(X_test)[:,1], multioutput='variance_weighted')

0.32642659661798396

这个 R2 对模型的拟合优度更有意义,看起来两个 R2 只是一个 +/- 符号开关,但事实并非如此。在我使用更大样本的模型中,R2 cross-val 为 -0.24,R2 test 为 0.18。而且,当我添加一个似乎有利于模型的功能时,R2 测试会上升,而 R2 交叉验证会下降

此外,如果您将 LogisticRegression 切换为 LinearRegression,R2 交叉验证现在为正,并且接近 R2 测试。是什么导致了这个问题?

【问题讨论】:

我必须通过X = X.reshape((-1, 1))X 提供给LogisticRegression。你也是这样吗?如果是,您可以编辑代码吗? 'model.predict_proba(X)[:,1]' 我忘记在帖子中添加这个,但后来编辑了 但是你为什么要在分类问题上使用r2-scoreLogisticRegression 是一个分类器,您的目标似乎也是二进制类。 LogisticRegression 被编程为给出离散值,因此 R-Squared 不适合在这种情况下检查性能。其他情况下的 LinearRegression 是回归量 @VivekKumar 你会使用什么拟合优度测试? 为什么需要分类拟合优度。分类问题涉及基于特征识别类之间的最佳边界,而不是在特征上拟合曲线。您可以尝试使用accuracyrecallprecisionauc 等。见this page for more details 【参考方案1】:

TLDR:R2 可能为负数,您误解了train_test_split 结果。

我将在下面解释这两种说法。

cross_val_score 符号翻转为 errorloss 指标

从docs,您可以看到cross_val_score 实际上翻转了某些指标的符号。但仅适用于errorloss 指标(越低越好),不适用于score 指标(越高越好):

所有记分器对象都遵循较高返回值优于较低返回值的约定。因此,衡量模型和数据之间距离的指标,如 metrics.mean_squared_error,可以作为 neg_mean_squared_error 使用,它返回指标的否定值。

由于r2score 指标,因此它不会翻转符号。您将在交叉验证中获得 -0.33。请注意,这是正常的。来自r2_scoredocs:

最好的分数是 1.0,可能是负数(因为模型可以任意变差)。始终预测 y 的期望值的常量模型,不考虑输入特征,将获得 0.0 的 R^2 分数。

所以这将我们引向第二部分:为什么使用 CV 和训练/测试拆分得到如此不同的结果?

CV 和训练/测试分割结果的区别

使用train_test_split 获得更好结果的原因有两个。

根据概率而不是类评估 r2(您使用 predict_proba 而不是 predict 可以减少错误的危害:

print(r2_score(y_test, model.predict_proba(X_test)[:,1], multioutput='variance_weighted'))
 0.19131536389654913

同时:

 print(r2_score(y_test, model.predict(X_test)))
 -0.364200082678793

10 folds cv 的平均值,而不检查方差,这很高。如果你检查方差和结果的细节,你会发现方差很大:

scores = cross_val_score(model, X, y, cv=10, scoring='r2')
scores
array([-0.67868339, -0.03918495,  0.04075235, -0.47783251, -0.23152709,
   -0.39573071, -0.72413793, -0.66666667,  0.        , -0.16666667])

scores.mean(), scores.std() * 2
(-0.3339677563815496, 0.5598543351649792)

希望对您有所帮助!

【讨论】:

哇。所以cross_val默认使用.predict而不是.predict_proba...有没有办法让cross_val_score使用proba? 我刚刚在其他答案中发现了这一点。 cross_val_predict(method=‘predict_proba’) 如何将 cross_val 用于 predict_proba?我使用了cross_val_predict,但我才知道它不是用于计算 R2。 嗨,您可以使用 scoring="neg_log_loss",它使用 predict_proba 进行计算,并且更适合回归模型,正如 Vivek 在对问题的评论中所说。请注意,它会返回负值,因为它被否定了,而且它是一个错误度量,因此较低的绝对值更好。【参考方案2】:

R2 可能为负数。以下段落来自“确定系数”的wikipedia页面

在某些情况下,R2 的计算定义可能会产生负值,具体取决于所使用的定义。当与相应结果进行比较的预测没有从使用这些数据的模型拟合过程中得出时,就会出现这种情况。即使使用了模型拟合程序,R2 仍可能为负值,例如当进行线性回归而不包括截距时,或使用非线性函数拟合数据时。根据这一特定标准,在出现负值的情况下,数据的平均值比拟合函数值更适合结果。由于决定系数的最一般定义也称为 Nash-Sutcliffe 模型效率系数,因此最后一种表示法在许多领域中是首选,因为它表示可以从 -infinity 到 1 变化的拟合优度指标(即, 它可以产生负值) 与平方字母混淆。

看来预测比水平线差。

【讨论】:

以上是关于交叉验证给出负 R2?的主要内容,如果未能解决你的问题,请参考以下文章

pls R 的交叉验证中如何计算 R2 和 RMSE

R 与 scikit-learn 中用于线性回归 R2 的交叉验证

sklearn中的Kfold交叉验证每次都会给出不同的折叠

Scikit-learn SVC 在随机数据交叉验证中总是给出准确度 0

使用 R 对 randomForest 执行交叉验证

交叉验证与网格搜索