sklearn 中关于 GridSearchCV 的说明

Posted

技术标签:

【中文标题】sklearn 中关于 GridSearchCV 的说明【英文标题】:Clarifications on GridSearchCV in sklearn 【发布时间】:2020-07-09 03:13:44 【问题描述】:

我对 sklearn 中的 GridSearchCV 有以下问题。我试过但找不到明确的答案。 下面是我使用的代码补丁 -

dep = df2['responder_flag']
indep = df2.drop(df2.columns[[0,85]], axis = 1)

X_train, X_test, y_train, y_test = train_test_split(indep, dep,test_size=0.25, random_state = 23)

train = xgb.XGBClassifier(objective='binary:logistic')
param_grid = 'max_depth': [4,5], 'n_estimators': [500], 'learning_rate': [0.02,0.01]
grid = GridSearchCV(train, param_grid,cv=5, scoring='roc_auc')
grid.fit(X_train, y_train)

    cross_validationGridSearchCV 中的 cv 参数是否等同于 Kfold 或其他在训练数据时使用 cross_validation_score 和其他类似函数显式应用的 CV 技术?

    我可以使用GridsearchCV 进行交叉验证吗? 说如果我不提供多个参数列表,它是否等于交叉验证技术?

    一旦执行了grid.fit(X_train, y_train) 语句,是否会根据已识别的最佳参数训练模型并可直接用于模型预测,或者我是否需要使用grid.best_params_ 定义另一个估计器然后训练并使用它用于预测?

抱歉,如果这些问题早些得到答复。

【问题讨论】:

关于第1点,cross_validation参数是什么意思? 嘿,我的意思是 GridSearchCV 中的 cv 参数;如代码中所述:(cv = 5) 【参考方案1】:

以下是答案:

    cv参数相当于k-fold。 在GridSearchCV 中,我们给出了一组我们希望模型采用的参数值。假设我们从 [0.0001, 0.001, 0.01, 0.1, 1, 10] 中取 learning_rate = 0.0001。当我们在gridsearch 中指定 cv=5 时,它将为000.1 执行5-fold cv。同样,它也会对剩余的值执行5-fold cv。在这种情况下,k 是 5。

    从某种意义上说,是的。但不要这样做,因为 GridSearchCV 需要一个参数列表。 GridSearchCV 是一种执行超参数调优的方法。如果你不指定多个参数列表,它就违背了使用 GridSearch 的目的。

    在完成grid.fit(X_train, y_train) 后,无需手动在训练集上拟合带有grid.best_params_ 的模型。 GridSearchv 有一个名为refit 的参数,如果我们设置refit = True,它将自动将grid.best_esitmator_ 重新拟合到整个训练集。默认设置为True。 Documentation

【讨论】:

感谢索拉布的回答!我问第 2 步的原因是因为我试图将 cv 应用于 xgboost(不是 sk-learn 中的 xgboost),并且无法使用 xgboost.cv 函数正确处理它。认为这是一种解决方法。因此,如果我发送一个参数列表,例如 - params_list = 'max_depth': [4], 'n_estimators': [500], 'learning_rate': [0.01],它将有助于处理交叉验证(技术上) 我相信。这是正确的理解吗? 如果您不需要超参数调优,那么GridSearchCV 不是可行的方法,因为像这样为 GridSearchCV 使用模型的默认参数,只会产生一个参数网格组合,所以它就像只执行 CV。这样做是没有意义的 - 如果我正确理解了你的问题 cv=5 处理分类问题,它将执行stratifiedKfold 而不仅仅是kfold

以上是关于sklearn 中关于 GridSearchCV 的说明的主要内容,如果未能解决你的问题,请参考以下文章

sklearn中关于空数组的弃用错误,我的代码中没有任何空数组

GridSearchCV(sklearn)中的多个估计器

GridsearchCV sklearn 中的错误

你能从 sklearn 网格搜索 (GridSearchCV) 中获得所有估计器吗?

无法使用 sklearn 的 GridSearchCV 运行 tflearn

GridSearchCV/RandomizedSearchCV 与 sklearn 中的 partial_fit