如何在 GridSearchCV 的 keras 模型的超参数优化中使用简单的验证集?

Posted

技术标签:

【中文标题】如何在 GridSearchCV 的 keras 模型的超参数优化中使用简单的验证集?【英文标题】:How to use a simple validation set in hyperparameter optimization of keras model with GridSearchCV? 【发布时间】:2020-10-04 06:29:53 【问题描述】:

我正在尝试对大型数据集执行超参数优化。而且我想避免使用交叉验证cv 来加速优化。这就是为什么我想使用来自训练数据集的验证拆分 = 0.2 的验证集。

   grid = GridSearchCV(estimator=model, param_grid=param_grid, n_jobs=-1, cv=3)
   grid_result = grid.fit(X_train, y_train)

我应该如何修改上面的 GridSearchCV() 参数以使用带有validation_split=0.2 的验证数据集并忽略交叉验证来执行超参数优化?

【问题讨论】:

您的意思是要始终为 hyper opt 使用相同的数据集吗? 是的,我想每次都使用相同的验证集来评估网格搜索 我添加了一个答案...不要忘记投票并接受它作为答案;-) 如果有问题请告诉我 这能回答你的问题吗? Using explicit (predefined) validation set for grid search with sklearn 【参考方案1】:

使用 PredefinedSplit,您可以为超参数选择使用相同的验证集。 -1 识别您的火车数据,而 0 识别您的有效数据

from sklearn.model_selection import PredefinedSplit, GridSearchCV

X_train = np.random.uniform(0,1, (10000,30))
y_train = np.random.uniform(0,1, 10000)
val_spilt = np.random.choice([-1,0], len(y_train), p=[0.8, 0.2])

grid = GridSearchCV(estimator=model, param_grid=param_grid, n_jobs=-1, 
                    cv=PredefinedSplit(val_spilt))
grid_result = grid.fit(X_train, y_train)

在此处进行手动检查:

ps = PredefinedSplit(val_spilt)
for train_index, val_index in ps.split():
    print("TRAIN:", len(train_index), "VAL:", len(val_index))

【讨论】:

以上是关于如何在 GridSearchCV 的 keras 模型的超参数优化中使用简单的验证集?的主要内容,如果未能解决你的问题,请参考以下文章

使用 keras 在虹膜上的 GridSearchCV 结果不佳

使用 Keras 和 sklearn GridSearchCV 交叉验证提前停止

Keras训练神经网络进行分类并使用GridSearchCV进行参数寻优

如何保存 GridSearchCV 对象?

将 gridsearchCV 与 Keras RNN-LSTM 一起使用时出现尺寸错误

keras + scikit-learn 包装器,当 GridSearchCV 与 n_jobs >1 时似乎挂起