避免过度拟合使用 keras 的序列问题
Posted
技术标签:
【中文标题】避免过度拟合使用 keras 的序列问题【英文标题】:Avoid overfitting in sequence to sequence problem using keras 【发布时间】:2019-06-03 02:35:39 【问题描述】:我想训练的模型有问题。
这是一个典型的带有注意力层的 seq-to-seq 问题,其中输入是字符串,输出是提交字符串的子字符串。
例如
Input Ground Truth
-----------------------------
helloimchuck chuck
johnismyname john
(这只是一个虚拟数据,不是数据集的真实部分^^)
模型看起来像这样:
model = Sequential()
model.add(Bidirectional(GRU(hidden_size, return_sequences=True), merge_mode='concat',
input_shape=(None, input_size))) # Encoder
model.add(Attention())
model.add(RepeatVector(max_out_seq_len))
model.add(GRU(hidden_size * 2, return_sequences=True)) # Decoder
model.add(TimeDistributed(Dense(units=output_size, activation="softmax")))
model.compile(loss="categorical_crossentropy", optimizer="rmsprop", metrics=['accuracy'])
问题出在这里:
如您所见,存在过拟合。
我使用patience=8
对验证损失使用提前停止标准。
self.Early_stop_criteria = keras.callbacks.EarlyStopping(monitor='val_loss', min_delta=0,
patience=8, verbose=0,
mode='auto')
我正在使用 one-hot-vector 来拟合模型。
BATCH_SIZE = 64
HIDDEN_DIM = 128
问题是,我尝试了其他批量大小、其他隐藏维度、10K 行、15K 行、25K 行和现在 50K 行的数据集。但是,总是存在过度拟合,我不知道为什么。
test_size = 0.2
和 validation_split=0.2
。这些是我没有改变的唯一参数。
我还确保数据集正确构建。
我唯一的想法是尝试另一个验证拆分,可能是0.33
而不是0.2
。
我不知道cross-validation
是否会有所帮助。
也许有人有更好的主意,我可以尝试一下。提前致谢。
【问题讨论】:
你试过dropout或者batchnorm吗? 并非如此。你会在输入和第一个隐藏层之间使用它吗? 您可以使用 2 种类型的 dropout。请参阅 this answer 和 this answer 了解更多详细信息,了解它们的区别以及如何将它们与 Keras 一起使用以及我们可以在哪里使用它们:) 【参考方案1】:正如 kvish 建议的那样,dropout 是一个很好的解决方案。
我首先尝试使用 0.2 的 dropout。
model = Sequential()
model.add(Bidirectional(GRU(hidden_size, return_sequences=True, dropout=0.2), merge_mode='concat',
input_shape=(None, input_size))) # Encoder
model.add(Attention())
model.add(RepeatVector(max_out_seq_len))
model.add(GRU(hidden_size * 2, return_sequences=True)) # Decoder
model.add(TimeDistributed(Dense(units=output_size, activation="softmax")))
model.compile(loss="categorical_crossentropy", optimizer="rmsprop", metrics=['accuracy'])
对于 50K 行,它可以工作,但仍然存在过度拟合。
所以,我尝试了 0.33 的 dropout,效果很好。
【讨论】:
以上是关于避免过度拟合使用 keras 的序列问题的主要内容,如果未能解决你的问题,请参考以下文章