在 scikit-learn 中训练神经网络时提前停止
Posted
技术标签:
【中文标题】在 scikit-learn 中训练神经网络时提前停止【英文标题】:Early-stopping while training neural network in scikit-learn 【发布时间】:2014-03-22 15:49:13 【问题描述】:这个问题是针对 Python 库 scikit-learn 的。请让我知道将其发布在其他地方是否更好。谢谢!
现在的问题...
我有一个基于 BaseEstimator 的前馈神经网络类 ffnn,我使用 SGD 进行训练。它工作正常,我也可以使用 GridSearchCV() 并行训练它。
现在我想在函数 ffnn.fit() 中实现提前停止,但为此我还需要访问折叠的验证数据。一种方法是更改 sklearn.grid_search.fit_grid_point() 中的行
clf.fit(X_train, y_train, **fit_params)
变成类似的东西
clf.fit(X_train, y_train, X_test, y_test, **fit_params)
并更改 ffnn.fit() 以获取这些参数。这也会影响 sklearn 中的其他分类器,这是一个问题。我可以通过检查 fit_grid_point() 中的某种标志来避免这种情况,该标志告诉我何时以上述两种方式之一调用 clf.fit()。
在我不必编辑 sklearn 库中的任何代码的情况下,有人可以建议一种不同的方法吗?
或者,进一步将 X_train 和 y_train 随机拆分为训练/验证集并检查一个好的停止点,然后在所有 X_train 上重新训练模型是否正确?
谢谢!
【问题讨论】:
【参考方案1】:您可以让您的神经网络模型在内部使用 train_test_split
函数从传递的 X_train
和 y_train
中提取验证集。
编辑:
或者,进一步将 X_train 和 y_train 随机拆分为训练/验证集并检查一个好的停止点,然后在所有 X_train 上重新训练模型是否正确?
是的,但那会很贵。您可以只找到停止点,然后对用于查找停止点的验证数据进行一次额外的传递。
【讨论】:
谢谢! @ogrisel:验证数据一次通过就足够了吗?我如何检查它是否可以通过多次传递变得更好? 您可以将最终的考试成绩与您最初但成本更高的方法的考试成绩进行比较。 谢谢!并对这个琐碎的问题感到抱歉。这当然是要做的事情:)。【参考方案2】:有两种方式:
第一:
同时进行 x_train 和 x_test 拆分。您可以从 x_train 中拆分 0.1 并将其保留用于验证 x_dev:
x_train, x_test, y_train, y_test = train_test_split(data_x, data_y, test_size=0.25)
x_train, x_dev, y_train, y_dev = train_test_split(x_train, y_train, test_size=0.1)
clf = GridSearchCV(YourEstimator(), param_grid=param_grid,)
clf.fit(x_train, y_train, x_dev, y_dev)
您的估算器将如下所示,并使用 x_dev、y_dev 实现提前停止
class YourEstimator(BaseEstimator, ClassifierMixin):
def __init__(self, param1, param2):
# perform initialization
#
def fit(self, x, y, x_dev=None, y_dev=None):
# perform training with early stopping
#
第二
您不会对 x_train 执行第二次拆分,而是会在 Estimator 的 fit 方法中取出开发集
x_train, x_test, y_train, y_test = train_test_split(data_x, data_y, test_size=0.25)
clf = GridSearchCV(YourEstimator(), param_grid=param_grid)
clf.fit(x_train, y_train)
您的估算器将如下所示:
class YourEstimator(BaseEstimator, ClassifierMixin):
def __init__(self, param1, param2):
# perform initialization
#
def fit(self, x, y):
# perform training with early stopping
x_train, x_dev, y_train, y_dev = train_test_split(x, y,
test_size=0.1)
【讨论】:
以上是关于在 scikit-learn 中训练神经网络时提前停止的主要内容,如果未能解决你的问题,请参考以下文章
如何在 TensorFlow 1.4 中使用提前停止来训练深度神经网络?
Scikit-Learn:标签不是 x 出现在所有训练示例中
使用 scikit-learn 训练数据时,SVM 多类分类停止