为啥 Keras Early Stopping 功能不会停止训练,虽然监测值在增加?

Posted

技术标签:

【中文标题】为啥 Keras Early Stopping 功能不会停止训练,虽然监测值在增加?【英文标题】:Why does the Keras Early Stopping function not stop the training, although the monitored value is increasing?为什么 Keras Early Stopping 功能不会停止训练,虽然监测值在增加? 【发布时间】:2019-01-20 10:21:56 【问题描述】:

我正在尝试针对回归问题训练神经网络,并且我实现了 Keras Early Stopping Function 以避免过度拟合。

现在,当我监控“val_loss”时,提前停止功能几乎直接停止了程序,结果是一个无用的 NN,但是当我监控“val_mse”时,尽管我可以看到“val_mse”,但训练会继续进行而不会停止通过训练增加,我设置耐心 = 0。

我似乎误解了 Early Stopping Callback,因为我认为它确实会监控值并在值再次开始增加时立即停止训练。

np.random.seed(7)

#Define Input
tf_features_64 = np.load("IN_2.npy")
tf_labels_64 = np.load("OUT_2.npy")
tf_features_32 = tf_features_64.astype(np.float32)
tf_labels_32 = tf_labels_64.astype(np.float32)

X = tf_features_32
Y = tf_labels_32[0:10680, 4:8]
#Define Callback
tbCallBack = TensorBoard(log_dir='./Graph', histogram_freq=0, write_graph=True, write_images=True) #TensorBoard Monitoring
esCallback = EarlyStopping(monitor='val_mse',
                           min_delta=0,
                           patience=0,
                           verbose=1,
                           mode='min')

#create Layers
visible = Input(shape=(33,))
x = Dropout(.1)(visible)
#x = Dense(63)(x)
#x = Dropout(.4)(x)
output = Dense(4)(x)  

Optimizer = optimizers.Adam(lr=0.001
                            #amsgrad = True)

model = Model(inputs=visible, outputs = output)
model.compile(optimizer=Optimizer,
              loss=['mse'],
              metrics=['mae', 'mse']
              )
model.fit(X, Y, epochs=8000, batch_size=20, shuffle=True, validation_split=0.35, callbacks=[tbCallBack, esCallback])

作为一个例子,我得到以下输出,我可以清楚地看到,val_mse 在各个时期增加。

  20/6942 [..............................] - ETA: 0s - loss: 0.0022 - mean_absolute_error: 0.0373 - mean_squared_error: 0.0022
1620/6942 [======>.......................] - ETA: 0s - loss: 0.0011 - mean_absolute_error: 0.0251 - mean_squared_error: 0.0011
3260/6942 [=============>................] - ETA: 0s - loss: 0.0015 - mean_absolute_error: 0.0290 - mean_squared_error: 0.0015
4900/6942 [====================>.........] - ETA: 0s - loss: 0.0017 - mean_absolute_error: 0.0301 - mean_squared_error: 0.0017
6500/6942 [===========================>..] - ETA: 0s - loss: 0.0016 - mean_absolute_error: 0.0301 - mean_squared_error: 0.0016
6942/6942 [==============================] - 0s 37us/step - loss: 0.0016 - mean_absolute_error: 0.0294 - mean_squared_error: 0.0016 - val_loss: 0.0011 - val_mean_absolute_error: 0.0240 - **val_mean_squared_error: 0.0011**
**Epoch 334/8000**

  20/6942 [..............................] - ETA: 0s - loss: 0.0025 - mean_absolute_error: 0.0367 - mean_squared_error: 0.0025
1620/6942 [======>.......................] - ETA: 0s - loss: 0.0012 - mean_absolute_error: 0.0257 - mean_squared_error: 0.0012
3260/6942 [=============>................] - ETA: 0s - loss: 0.0014 - mean_absolute_error: 0.0274 - mean_squared_error: 0.0014
4860/6942 [====================>.........] - ETA: 0s - loss: 0.0014 - mean_absolute_error: 0.0268 - mean_squared_error: 0.0014
6400/6942 [==========================>...] - ETA: 0s - loss: 0.0012 - mean_absolute_error: 0.0254 - mean_squared_error: 0.0012
6942/6942 [==============================] - 0s 39us/step - loss: 0.0012 - mean_absolute_error: 0.0249 - mean_squared_error: 0.0012 - val_loss: 0.0032 - val_mean_absolute_error: 0.0393 - **val_mean_squared_error: 0.0032**
**Epoch 335/8000**

  20/6942 [..............................] - ETA: 0s - loss: 9.5175e-04 - mean_absolute_error: 0.0243 - mean_squared_error: 9.5175e-04
1620/6942 [======>.......................] - ETA: 0s - loss: 0.0017 - mean_absolute_error: 0.0312 - mean_squared_error: 0.0017        
3260/6942 [=============>................] - ETA: 0s - loss: 0.0013 - mean_absolute_error: 0.0271 - mean_squared_error: 0.0013
4860/6942 [====================>.........] - ETA: 0s - loss: 0.0014 - mean_absolute_error: 0.0277 - mean_squared_error: 0.0014
6460/6942 [==========================>...] - ETA: 0s - loss: 0.0013 - mean_absolute_error: 0.0266 - mean_squared_error: 0.0013
6942/6942 [==============================] - 0s 38us/step - loss: 0.0013 - mean_absolute_error: 0.0268 - mean_squared_error: 0.0013 - val_loss: 0.0046 - val_mean_absolute_error: 0.0491 - **val_mean_squared_error: 0.0046**
**Epoch 336/8000**

【问题讨论】:

我认为val_mse 在 Keras 中没有任何内在意义,除非 Keras 在过去六个月中发生了显着变化。仅仅因为mse 是被识别的关键字,并不意味着val_mse 也被识别。为什么不使用val_loss 本身?如果你的损失是mse,那意味着val_loss = mse + regularization penalty。你应该使用val_loss 【参考方案1】:

val_mse 适用于 python 3.7 版,val_mean_squared_error 适用于 3.6 版

【讨论】:

【参考方案2】:

您的代码中没有名为 val_mse 的指标,您的回调正在监控错误的指标。有val_mean_squared_error,但与val_mse不同。

您应该将要监控的指标从 val_mse 更改为 val_mean_squared_error,它应该可以工作。

【讨论】:

以上是关于为啥 Keras Early Stopping 功能不会停止训练,虽然监测值在增加?的主要内容,如果未能解决你的问题,请参考以下文章

keras构建前馈神经网络(feedforward neural network)进行分类模型构建基于早停法(Early stopping)

带有“early_stopping_rounds”的 xgboost 的 cross_val_score 返回“IndexError”

早停法(Early Stopping)

XGBoost 与 GridSearchCV、缩放、PCA 和 sklearn 管道中的 Early-Stopping

如何找到训练 keras 模型的 epoch 数?

Early Stopping中基于测试集(而非验证集)上的表现选取模型的讨论