文本分类 CNN 过拟合训练

Posted

技术标签:

【中文标题】文本分类 CNN 过拟合训练【英文标题】:Text classification CNN overfits training 【发布时间】:2020-10-05 00:14:31 【问题描述】:

我正在尝试使用 CNN 架构对文本句子进行分类。网络架构如下:

text_input = Input(shape=X_train_vec.shape[1:], name = "Text_input")

conv2 = Conv1D(filters=128, kernel_size=5, activation='relu')(text_input)
drop21 = Dropout(0.5)(conv2)
pool1 = MaxPooling1D(pool_size=2)(drop21)
conv22 = Conv1D(filters=64, kernel_size=5, activation='relu')(pool1)
drop22 = Dropout(0.5)(conv22)
pool2 = MaxPooling1D(pool_size=2)(drop22)
dense = Dense(16, activation='relu')(pool2)

flat = Flatten()(dense)
dense = Dense(128, activation='relu')(flat)
out = Dense(32, activation='relu')(dense)

outputs = Dense(y_train.shape[1], activation='softmax')(out)

model = Model(inputs=text_input, outputs=outputs)
# compile
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

我有一些回调作为 early_stopping 和 reduceLR 来停止训练并在验证损失没有改善(减少)时降低学习率。

early_stopping = EarlyStopping(monitor='val_loss', 
                               patience=5)
model_checkpoint = ModelCheckpoint(filepath=checkpoint_filepath,
                                   save_weights_only=False,
                                   monitor='val_loss',
                                   mode="auto",
                                   save_best_only=True)
learning_rate_decay = ReduceLROnPlateau(monitor='val_loss', 
                                        factor=0.1, 
                                        patience=2, 
                                        verbose=1, 
                                        mode='auto',
                                        min_delta=0.0001, 
                                        cooldown=0,
                                        min_lr=0)

训练模型后,训练的历史如下:

我们可以在这里观察到,从 epoch 5 开始,验证损失并没有改善,并且训练损失在每一步都被过度拟合。

我想知道我在 CNN 的架构中是否做错了什么?辍学层还不足以避免过度拟合吗?还有哪些其他方法可以减少过拟合?

有什么建议吗?

提前致谢。


编辑:

我也尝试过正则化,结果更糟:

kernel_regularizer=l2(0.01), bias_regularizer=l2(0.01)


编辑 2:

我尝试在每次卷积后应用 BatchNormalization 层,结果是下一个:

norm = BatchNormalization()(conv2)


编辑 3:

应用 LSTM 架构后:

text_input = Input(shape=X_train_vec.shape[1:], name = "Text_input")

conv2 = Conv1D(filters=128, kernel_size=5, activation='relu')(text_input)
drop21 = Dropout(0.5)(conv2)
conv22 = Conv1D(filters=64, kernel_size=5, activation='relu')(drop21)
drop22 = Dropout(0.5)(conv22)

lstm1 = Bidirectional(LSTM(128, return_sequences = True))(drop22)
lstm2 = Bidirectional(LSTM(64, return_sequences = True))(lstm1)

flat = Flatten()(lstm2)
dense = Dense(128, activation='relu')(flat)
out = Dense(32, activation='relu')(dense)

outputs = Dense(y_train.shape[1], activation='softmax')(out)

model = Model(inputs=text_input, outputs=outputs)
# compile
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

【问题讨论】:

最重要的问题是:“你的数据集有多大?”这似乎是一个非常小的数据集。如果是这种情况,您的主要反应应该是收集更多数据。 不小。我有 40000 个样本。 那肯定还有其他问题。由于验证损失几乎没有减少,它没有学到任何有用的东西。你有几节课?你的课程平衡吗? 有两个类,分布分别为58%、42% 嗯。没有更多信息很难猜测,但我仍然认为数据可能有问题。这是一个定制项目吗?你想预测什么? 【参考方案1】:

过拟合可能由多种因素引起,当您的模型与训练集拟合得太好时,就会发生这种情况。

要处理它,您可以采取一些方法:

    添加更多数据 使用数据增强 使用泛化能力强的架构 添加正则化(主要是dropout,L1/L2正则化也可以) 降低架构复杂性。

为了更清楚,您可以阅读https://towardsdatascience.com/deep-learning-3-more-on-cnns-handling-overfitting-2bd5d99abe5d

【讨论】:

【参考方案2】:

这是在尖叫迁移学习google-unversal-sentence-encoder 非常适合这个用例。将您的模型替换为

import tensorflow_hub as hub 
import tensorflow_text

text_input = Input(shape=X_train_vec.shape[1:], name = "Text_input")

# this next layer might need some tweaking dimension wise, to correctly fit
# X_train in the model
text_input = tf.keras.layers.Lambda(lambda x: tf.squeeze(x))(text_input)
# conv2 = Conv1D(filters=128, kernel_size=5, activation='relu')(text_input)
# drop21 = Dropout(0.5)(conv2)
# pool1 = MaxPooling1D(pool_size=2)(drop21)
# conv22 = Conv1D(filters=64, kernel_size=5, activation='relu')(pool1)
# drop22 = Dropout(0.5)(conv22)
# pool2 = MaxPooling1D(pool_size=2)(drop22)

# 1) you might need `text_input = tf.expand_dims(text_input, axis=0)` here
# 2) If you're classifying English only, you can use the link to the normal `google-universal-sentence-encoder`, not the multilingual one
# 3) both the English and multilingual have a `-large` version. More accurate but slower to train and infer. 
embedded = hub.KerasLayer('https://tfhub.dev/google/universal-sentence-encoder-multilingual/3')(text_input) 

# this layer seems out of place, 
# dense = Dense(16, activation='relu')(embedded) 

# you don't need to flatten after a dense layer (in your case) or a backbone (in my case (google-universal-sentence-encoder))
# flat = Flatten()(dense)

dense = Dense(128, activation='relu')(flat)
out = Dense(32, activation='relu')(dense)

outputs = Dense(y_train.shape[1], activation='softmax')(out)

model = Model(inputs=text_input, outputs=outputs)

【讨论】:

谢谢,我试一试,我会告诉你的。 这对您不起作用吗?如果没有,我将再次删除答案。 我在执行时遇到问题:ValueError: Shape must be rank 1 but is rank 2 for 'text_preprocessor/tokenize/StringSplit/StringSplit' (op: 'StringSplit') with input shapes: [?, 1],[]。 您是否在tf.KerasLayer 之前尝试过text_input = tf.expand_dims(text_input, axis=0)text_input = tf.squeeze(text_input) 两者都以这个错误结束:AttributeError: 'tuple' object has no attribute 'layer'【参考方案3】:

我认为,由于您正在进行文本分类,因此添加 1 或 2 个 LSTM 层可能有助于网络更好地学习,因为它将能够更好地与数据的上下文相关联。我建议在 flatten 层之前添加以下代码。

lstm1 = Bidirectional(LSTM(128, return_sequence = True))
lstm2 = Bidirectional(LSTM(64))

LSTM 层可以帮助神经网络学习某些单词之间的关联,并可能提高网络的准确性。

我还建议删除 Max Pooling 层,因为最大池化尤其是在文本分类中会导致网络丢弃一些有用的功能。 只保留卷积层和 dropout。还要在展平之前移除 Dense 层并添加上述 LSTM。

【讨论】:

我已经尝试过您提出的架构,但结果比仅使用卷积最差。验证损失直到第 12 个 epoch 才开始下降,一旦开始下降,它只提高了 1%,从 0.5754 到 0.5913。此外,val_loss 总是高于 train loss(准确性也更高)。我想这是一个奇怪的情况,我在问题中附上了这个数字作为编辑。 尝试添加嵌入层。 我尝试添加一个快速文本嵌入层。结果也很糟糕。【参考方案4】:

尚不清楚如何将文本输入模型。我假设您对文本进行标记以将其表示为整数序列,但是在将其输入模型之前,您是否使用了任何词嵌入?如果没有,我建议您在模型开始时抛出可训练的 tensorflow Embedding 层。有一种称为 Embedding Lookup 的巧妙技术可以加快其训练速度,但您可以将其保存以备后用。尝试将此层添加到您的模型中。然后你的Conv1D 层将更容易处理一系列浮点数。另外,我建议你在每个Conv1D 后面加上BatchNormalization,它应该有助于加快收敛和训练。

【讨论】:

我添加了一个快速文本嵌入层和批量标准化,但结果与其余部分相同。

以上是关于文本分类 CNN 过拟合训练的主要内容,如果未能解决你的问题,请参考以下文章

为啥我的 CNN 预训练图像分类器过拟合?

如何解决基于 NLP 的 CNN 模型中的过度拟合问题,以使用词嵌入进行多类文本分类?

拟合多标签文本分类模型时的错误

[Pytorch系列-61]:循环神经网络 - 中文新闻文本分类详解-3-CNN网络训练与评估代码详解

利用CNN进行图像分类的流程(猫狗大战为例)

这是过拟合的情况吗? CNN图像分类器