如何改进我的 CNN?高且持续的验证错误
Posted
技术标签:
【中文标题】如何改进我的 CNN?高且持续的验证错误【英文标题】:How to improve my CNN ? high and constant validation error 【发布时间】:2019-11-25 10:14:24 【问题描述】:我正在研究一个问题,根据奶牛的图像预测奶牛的肥胖程度。 我应用了一个 CNN 来估计介于 0-5 之间的值(我拥有的数据集,仅包含 2.25 和 4 之间的值) 我正在使用 4 个 CNN 层和 3 个隐藏层。
我实际上有两个问题: 1/ 我得到了 0.05 的训练误差,但是在 3-5 个 epoch 之后,验证误差保持在 0.33 左右。 2/ 我的 NN 预测的值在 2.9 到 3.3 之间,与数据集范围相比太窄了。正常吗?
如何改进我的模型?
model = tf.keras.models.Sequential([
tf.keras.layers.Conv2D(16, (3,3), activation='relu', input_shape=(512, 424,1)),
tf.keras.layers.MaxPooling2D(2, 2),
tf.keras.layers.Conv2D(32, (3,3), activation='relu'),
tf.keras.layers.MaxPooling2D(2, 2),
tf.keras.layers.Conv2D(32, (3,3), activation='relu'),
tf.keras.layers.MaxPooling2D(2, 2),
tf.keras.layers.Conv2D(64, (3,3), activation='relu'),
tf.keras.layers.MaxPooling2D(2,2),
tf.keras.layers.Flatten(input_shape=(512, 424)),
tf.keras.layers.Dense(256, activation=tf.nn.relu),
tf.keras.layers.Dense(128, activation=tf.nn.relu),
tf.keras.layers.Dense(64, activation=tf.nn.relu),
tf.keras.layers.Dense(1, activation='linear')
])
学习曲线:
预测:
【问题讨论】:
【参考方案1】:这似乎是过拟合的情况。你可以
Shuffle
Data
,通过在cnn_model.fit
中使用shuffle=True
。代码如下:
history = cnn_model.fit(x = X_train_reshaped,
y = y_train,
batch_size = 512,
epochs = epochs, callbacks=[callback],
verbose = 1, validation_data = (X_test_reshaped, y_test),
validation_steps = 10, steps_per_epoch=steps_per_epoch, shuffle = True)
使用Early Stopping
。代码如下所示
callback = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=15)
使用正则化。正则化代码如下(你可以试试l1正则化或者l1_l2正则化):
from tensorflow.keras.regularizers import l2
Regularizer = l2(0.001)
cnn_model.add(Conv2D(64,3, 3, input_shape = (28,28,1), activation='relu', data_format='channels_last',
activity_regularizer=Regularizer, kernel_regularizer=Regularizer))
cnn_model.add(Dense(units = 10, activation = 'sigmoid',
activity_regularizer=Regularizer, kernel_regularizer=Regularizer))
您可以尝试使用BatchNormalization
。
使用ImageDataGenerator
执行图像数据增强。有关详细信息,请参阅 this link。
如果像素不是Normalized
,将像素值除以255
也有帮助。
最后,如果还是没有变化,可以试试Pre-Trained Models
,比如ResNet
或者VGG Net
等。
【讨论】:
以上是关于如何改进我的 CNN?高且持续的验证错误的主要内容,如果未能解决你的问题,请参考以下文章