嵌套交叉验证的正确程序是啥?
Posted
技术标签:
【中文标题】嵌套交叉验证的正确程序是啥?【英文标题】:What is the correct procedure for nested cross-validation?嵌套交叉验证的正确程序是什么? 【发布时间】:2021-01-22 02:18:51 【问题描述】:我正在尝试使用 scikit-learn 制作分类器,然后预测分类器的准确性。我的数据集相对较小,我不确定最佳参数。因此,我转向嵌套交叉验证 (nCV) 来制作和测试我的模型。
我一直在努力了解最好的方法。不过看完后:
-
https://stats.stackexchange.com/questions/229509/do-i-need-an-initial-train-test-split-for-nested-cross-validation
https://stats.stackexchange.com/questions/410118/cross-validation-vs-train-validation-test/410206
https://stats.stackexchange.com/questions/95797/how-to-split-the-dataset-for-cross-validation-learning-curve-and-final-evaluat
我仍然不知道最好的方法。
到目前为止,我有:
-
将整个数据集拆分 (80%/20%) 为训练集和测试集
定义了我的内部 cv、外部 cv、参数网格和估计器(随机森林)
运行 nCV 以获得平均准确度得分。
为此,我目前的代码是:
X_train, X_test, Y_train, Y_test = train_test_split(X_res, Y_res, test_size=0.2)
inner_cv = KFold(n_splits=2, shuffle=True)
outer_cv = KFold(n_splits=2, shuffle=True)
rfc = RandomForestClassifier()
param_grid = 'bootstrap': [True, False],
'max_depth': [10, 20, 30, 40, 50, 60, 70, 80, 90, 100, None],
'max_features': ['auto', 'sqrt', 'log2', None],
'min_samples_leaf': [1, 2, 4, 25],
'min_samples_split': [2, 5, 10, 25],
'criterion': ['gini', 'entropy'],
'n_estimators': [200, 400, 600, 800, 1000, 1200, 1400, 1600, 1800, 2000]
rfclf = RandomizedSearchCV(rfc, param_grid, cv=inner_cv, n_iter=100, n_jobs=-1, scoring='accuracy', verbose=1)
nested_cv_results = cross_val_score(rfclf, X_train, Y_trin, cv=outer_cv, scoring = 'accuracy')
我现在有 2 个问题:
-
如何找到总体上最好的模型?
如何针对 X_test 和 Y_test 测试这个最佳模型?
【问题讨论】:
【参考方案1】:交叉验证用于评估模型性能或调整超参数。假设您使用 CV 来调整您的超参数,您不能使用这些 CV 分数来评估模型性能,即,由于数据泄漏,您会得到一个过度乐观的估计。这就是嵌套 CV 可以帮助您的地方。通过添加额外的 CV 层,您可以防止数据泄漏。因此,嵌套 CV 用于获得模型性能的无偏估计。
为回答您的问题,在您对 X_train/y_train 完成嵌套 CV 后,您已获得对模型性能的无偏估计。接下来,使用 X_train/y_train 上的 RandomizedSearchCV 再次调整模型超参数。从此搜索中获得最佳模型并将其用于您的 X_test/y_test。
示例代码:
X_train, X_test, Y_train, Y_test = train_test_split(X_res, Y_res, test_size=0.2)
inner_cv = KFold(n_splits=2, shuffle=True)
outer_cv = KFold(n_splits=2, shuffle=True)
rfc = RandomForestClassifier()
param_grid = 'bootstrap': [True, False],
'max_depth': [10, 20, 30, 40, 50, 60, 70, 80, 90, 100, None],
'max_features': ['auto', 'sqrt', 'log2', None],
'min_samples_leaf': [1, 2, 4, 25],
'min_samples_split': [2, 5, 10, 25],
'criterion': ['gini', 'entropy'],
'n_estimators': [200, 400, 600, 800, 1000, 1200, 1400, 1600, 1800, 2000]
rfclf = RandomizedSearchCV(rfc, param_grid, cv=inner_cv, n_iter=100, n_jobs=-1, scoring='accuracy', verbose=1, refit=True)
nested_cv_results = cross_val_score(rfclf, X_train, Y_train, cv=outer_cv, scoring = 'accuracy')
random = RandomizedSearchCV(rfc, param_grid, cv=inner_cv, n_iter=100, n_jobs=-1, scoring='accuracy', verbose=1, refit=True)
random.fit(X_train, Y_train)
random.best_estimator_.score(X_test, Y_test)
【讨论】:
以上是关于嵌套交叉验证的正确程序是啥?的主要内容,如果未能解决你的问题,请参考以下文章
SKlearn中具有嵌套交叉验证的分类报告(平均值/个体值)
使用 sklearn 在嵌套交叉验证中使用 GroupKFold