Scikit GridSearchCV - fit() 和 predict() 如何与 ColumnTranformers 和 Pipelines 结合使用

Posted

技术标签:

【中文标题】Scikit GridSearchCV - fit() 和 predict() 如何与 ColumnTranformers 和 Pipelines 结合使用【英文标题】:Scikit GridSearchCV - How does fit() and predict() work in conjunction with ColumnTranformers and Pipelines 【发布时间】:2021-10-06 16:22:47 【问题描述】:

我对 GridSearchCV 的实际工作方式有点困惑,所以让我们想象一个任意回归问题,我想预测房子的价格:

假设我们使用一个简单的预处理器来对训练集进行目标编码: 目标编码器应调用 X_train 上的 fit_transform() 和 X_test 上的 transform() 以防止数据泄漏。

preprocessor = ColumnTransformer(
    transformers=
    [      
        ('encoded_target_price', TargetEncoder(), ["Zipcodes"]),  
    ],
     remainder='passthrough',n_jobs=-1)

我们使用一些带有缩放功能的管道,同样,缩放器应该可以在以下方面工作 训练和测试集。

pipe = Pipeline(steps=[("preprocessor", preprocessor),
                       ("scaler", RobustScaler()),
                       ('clf', LinearSVR()),
                      ])

使用一些任意参数初始化 GridSearch:

gscv = GridSearchCV(estimator = pipe, 
                    param_grid = tuned_parameters,                
                    cv = kfold,                                   
                    n_jobs = -1,
                    random_state=seed
                    )

现在我们可以拨打gscv.fit(X_train, ytrain)gscv.predict(X_test)

我不明白这是如何工作的。例如通过调用 fit() 目标编码器 和 Scaler 适合训练集,但它们永远不会被转换,因此数据永远不会改变。 GridSearch 如何根据未转换的训练集计算分数?

predict 方法我完全不懂。如何在不将preprocessor 的转换应用到测试集X_test 的情况下做出预测?我的意思是,当我在训练集上进行一些大的转换(例如缩放、编码等)时,它们也必须在测试集上进行吗?

但是Gridsearch内部只调用了best_estimator_.predict(),那么测试集上的.transform()是在哪里发生的呢?

【问题讨论】:

【参考方案1】:

在调用管道的predict() 函数时会隐式应用数据转换。在documentation中明确提到:

对数据应用变换,并使用最终估计器进行预测

因此无需显式转换数据。它在最终估计器做出预测之前自动完成。也没有数据泄露,因为管道在对数据应用predict()时会调用每一步的transform()方法。

【讨论】:

以上是关于Scikit GridSearchCV - fit() 和 predict() 如何与 ColumnTranformers 和 Pipelines 结合使用的主要内容,如果未能解决你的问题,请参考以下文章

Scikit 管道参数 - fit() 得到了一个意外的关键字参数“gamma”

scikit-learn 中的超参数优化(网格搜索)

GridSearchCV/RandomizedSearchCV 与 sklearn 中的 partial_fit

为啥 GridSearchCV 没有给出最好的分数? - Scikit 学习

Scikit-learn 在 DecisionTreeClassifier 上使用 GridSearchCV

scikit-learn GridSearchCV 弃用警告