如果我们在管道中包含转换器,来自 scikit-learn 的“cross_val_score”和“GridsearchCV”的 k 折交叉验证分数是不是存在偏差?

Posted

技术标签:

【中文标题】如果我们在管道中包含转换器,来自 scikit-learn 的“cross_val_score”和“GridsearchCV”的 k 折交叉验证分数是不是存在偏差?【英文标题】:Are the k-fold cross-validation scores from scikit-learn's `cross_val_score` and `GridsearchCV` biased if we include transformers in the pipeline?如果我们在管道中包含转换器,来自 scikit-learn 的“cross_val_score”和“GridsearchCV”的 k 折交叉验证分数是否存在偏差? 【发布时间】:2019-12-30 06:20:28 【问题描述】:

应使用 StandardScaler 等数据预处理器对训练集进行 fit_transform,并且仅对测试集进行转换(不拟合)。我希望相同的拟合/转换过程适用于调整模型的交叉验证。但是,我发现cross_val_scoreGridSearchCV 用预处理器拟合了整个训练集(而不是 fit_transform 内部训练集,并转换内部验证集)。我相信这人为地消除了 inner_validation 集中的方差,这使得 cv 分数(用于通过 GridSearch 选择最佳模型的指标)有偏差。这是一个问题还是我真的错过了什么?

为了演示上述问题,我使用来自 Kaggle 的 Breast Cancer Wisconsin (Diagnostic) Data Set 尝试了以下三个简单的测试用例。

    我故意用StandardScaler() 拟合和变换整个X
X_sc = StandardScaler().fit_transform(X)
lr = LogisticRegression(penalty='l2', random_state=42)
cross_val_score(lr, X_sc, y, cv=5)
    我在 Pipeline 中包含 SC 和 LR 并运行 cross_val_score
pipe = Pipeline([
    ('sc', StandardScaler()),
    ('lr', LogisticRegression(penalty='l2', random_state=42))
])
cross_val_score(pipe, X, y, cv=5)
    与 2 相同,但使用 GridSearchCV
pipe = Pipeline([
    ('sc', StandardScaler()),
    ('lr', LogisticRegression(random_state=42))
])
params = 
    'lr__penalty': ['l2']

gs=GridSearchCV(pipe,
param_grid=params, cv=5).fit(X, y)
gs.cv_results_

它们都产生相同的验证分数。 [0.9826087 , 0.97391304, 0.97345133, 0.97345133, 0.99115044]

【问题讨论】:

【参考方案1】:

学习预测函数的参数并在相同的数据上对其进行测试是一个方法错误:一个模型只会重复它刚刚看到的样本的标签,它会获得完美的分数,但无法预测任何有用的东西关于尚未看到的数据。这种情况称为过拟合。为了避免这种情况,在执行(监督)机器学习实验时,通常的做法是保留部分可用数据作为测试集 X_test, y_test

解决此问题的方法是一种称为交叉验证(简称 CV)的过程。仍应保留测试集以进行最终评估,但在进行 CV 时不再需要验证集。在称为 k-fold CV 的基本方法中,训练集被分成 k 个更小的集合(其他方法在下面描述,但通常遵循相同的原则)。对于 k 个“折叠”中的每一个,都遵循以下过程:

使用折叠作为训练数据来训练模型; 生成的模型在数据的剩余部分上进行验证(即,它被用作测试集来计算诸如准确性之类的性能度量)。 然后,k 折交叉验证报告的性能度量是循环中计算的值的平均值。这种方法的计算成本可能很高,但不会浪费太多数据(就像修复任意验证集时的情况一样),这在样本数量非常少的逆推理等问题中是一个主要优势。

此外,如果您的模型从一开始就存在偏差,我们必须通过 SMOTE/Less 目标变量的过采样/High 目标变量的欠采样来使其平衡。

【讨论】:

感谢您的快速回复和详细解释。您回复中的第二段是我的问题真正关注的地方。以图中的split 1 为例: 欢迎。如果您对我的回答感到满意,请投票。 感谢您的快速回复和详细解释。您回复中的第二段是我的问题真正关注的地方。以图中的 split 1 为例:在 CV 期间,像 StandardScaler() 这样的预处理器应该 fit.transform 折叠 2-5(inner_train 集)并且只有 transform 折叠 1(验证集)。我认为 cross_val_score() 和 GridSearchCV 在计算 CV 分数时都不会这样做。相反,StandardScale() 是 fit.transform 整个折叠 1-5,我认为这是一个问题(如果为真)。希望这能澄清我的问题。再次感谢。 Aniruddha,感谢您在回复中发布 sklearn 的交叉验证图。我实际上在我的博客文章中使用了它。 好的,没问题。你投票给我的帖子。【参考方案2】:

不,sklearn 不会对整个数据集执行 fit_transform

为了检查这一点,我将StandardScaler 子类化以打印发送给它的数据集的大小。

class StScaler(StandardScaler):
    def fit_transform(self,X,y=None):
        print(len(X))
        return super().fit_transform(X,y)

如果您现在在代码中替换 StandardScaler,您会看到在第一种情况下传递的数据集大小实际上更大。

但为什么准确度保持不变?我认为这是因为LogisticRegression 对特征尺度不是很敏感。如果我们改为使用对比例非常敏感的分类器,例如KNeighborsClassifier,您会发现两种情况之间的准确度开始发生变化。

X,y = load_breast_cancer(return_X_y=True)
X_sc = StScaler().fit_transform(X)
lr = KNeighborsClassifier(n_neighbors=1)
cross_val_score(lr, X_sc,y, cv=5)

输出:

569
[0.94782609 0.96521739 0.97345133 0.92920354 0.9380531 ]

还有第二种情况,

pipe = Pipeline([
    ('sc', StScaler()),
    ('lr', KNeighborsClassifier(n_neighbors=1))
])
print(cross_val_score(pipe, X, y, cv=5))

输出:

454
454
456
456
456
[0.95652174 0.97391304 0.97345133 0.92920354 0.9380531 ]

在准确性方面变化不大,但仍然有所改变。

【讨论】:

这很有帮助,也是我正在寻找的确切答案。我用RandomForestClassifier 尝试了自己,现在CV 分数显示出差异。我从你的回复中学到了很多。谢谢! Shihab,再次感谢。我很惊讶像我这样的许多 Bootcamp 学生在使用 GridSearchCV 时实际上错误地应用了预处理器(在使用 GridSearch 调整模型之前预处理训练数据)。我打算写一篇关于这个主题的博客文章。如果您不介意,我想引用此讨论中的一些文本和代码,并将解决方案归功于您。让我知道这是否可以。谢谢! 谢谢!。完成后将转发链接。您的 cmets 将不胜感激。 Shihab,这里是博客文章“使用管道预处理数据以防止交叉验证期间数据泄漏”的链接。如果您有任何 cmets,请告诉我。 towardsdatascience.com/… @ShihabShahriarKhan Shihab,感谢您和 Kai 提供的有用的 QA。当我调用``` pipe = Pipeline([ ('sc', StandardScaler()), ('model', model(**parameters, random_state=42)) ])``` 然后我调用learning_curve(pipe, X_train, y_train, cv=RepeatedStratifiedKFold(n_splits=nb_splits, n_repeats=nb_repeats, random_state=42), scoring='accuracy') 确实这也仅将标准化应用于训练并将转换应用于cv loop 内的验证(即避免数据泄漏)?

以上是关于如果我们在管道中包含转换器,来自 scikit-learn 的“cross_val_score”和“GridsearchCV”的 k 折交叉验证分数是不是存在偏差?的主要内容,如果未能解决你的问题,请参考以下文章

在剑道的日期管道中包含变量值(动态格式):AngularJS

在 CMAKE 中包含来自 Android 项目的不同文件夹的静态库

Rails 5 - 如何在资产管道中包含所有供应商资产?

在 C++ 中包含来自单独文件夹的文件

如何在禁用管道中包含“jquery_ujs”以便能够对表单使用“remote:true”?

>=Rails 3.1 如何在资产管道中包含 IE 特定的 YAML-CSS 文件