为 CNN 模型实现交叉验证

Posted

技术标签:

【中文标题】为 CNN 模型实现交叉验证【英文标题】:Implementing Cross Validation for CNN model 【发布时间】:2020-06-01 06:42:54 【问题描述】:

我已经建立了我的 CNN 模型来对 8 个类别的图像进行分类。训练和测试步骤是通过随机拆分 80% 用于训练图像和 20% 用于测试图像来完成的,其中计算了 Acuuracy 和 F-measure 结果。

我注意到,与我的测试结果相比,我的训练准确度结果略低(我认为训练准确度应该更高)。经过大量搜索,我找到了两个原因:

1- dropput(0.5)的使用,影响训练准确率结果

2-测试数据集分类简单。

我计划通过进行 10-k 交叉验证来评估我的 CNN 分类器。但是,由于我还是这个领域的新手,所以我找到的大多数答案都是针对 .csv 文件的,我有图像文件。

如何编写代码以获得交叉验证?

我可以得到一个混淆矩阵来进行交叉验证吗?

我的代码:

from keras.models import Sequential
from keras.layers import Conv2D,Activation,MaxPooling2D,Dense,Flatten,Dropout
import numpy as np
from keras.preprocessing.image import ImageDataGenerator
from IPython.display import display
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.metrics import classification_report, confusion_matrix
import keras
from keras.layers import BatchNormalization
from keras.optimizers import Adam

classifier = Sequential()
classifier.add(Conv2D(32,(3,3),input_shape=(200,200,3)))
classifier.add(Activation('relu'))
classifier.add(MaxPooling2D(pool_size =(2,2)))
classifier.add(Flatten())
classifier.add(Dense(128))
classifier.add(Activation('relu'))
classifier.add(Dropout(0.5))
classifier.add(Dense(8))
classifier.add(Activation('softmax'))
classifier.summary()
classifier.compile(optimizer =keras.optimizers.Adam(lr=0.001),
                   loss ='categorical_crossentropy',
                   metrics =['accuracy'])
train_datagen = ImageDataGenerator(rescale =1./255,
                                   shear_range =0.2,
                                   zoom_range = 0.2,
                                   horizontal_flip =True)
test_datagen = ImageDataGenerator(rescale = 1./255)

batchsize=10
training_set = train_datagen.flow_from_directory('/home/osboxes/Downloads/Downloads/Journal_Paper/Train/',
                                                target_size=(200,200),
                                                batch_size= batchsize,
                                                class_mode='categorical')

test_set = test_datagen.flow_from_directory('/home/osboxes/Downloads/Downloads/Journal_Paper/Test/',
                                           target_size = (200,200),
                                           batch_size = batchsize,
                       shuffle=False,
                                           class_mode ='categorical')
history=classifier.fit_generator(training_set,
                        steps_per_epoch = 3067 // batchsize,
                        epochs = 50,
                        validation_data =test_set,
                        validation_steps = 769 // batchsize)


Y_pred = classifier.predict_generator(test_set, steps= 769 // batchsize + 1)
y_pred = np.argmax(Y_pred, axis=1)
print('Confusion Matrix')
print(confusion_matrix(test_set.classes, y_pred))
print('Classification Report')
target_names = test_set.classes
class_labels = list(test_set.class_indices.keys()) 
target_names = ['coinhive','emotent','fareit', 'flystudio', 'gafgyt','gamarue', 'mirai','razy'] 
report = classification_report(test_set.classes, y_pred, target_names=class_labels)
print(report) 

# summarize history for accuracy
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.title('model accuracy')
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend(['train', 'test'], loc='upper left')
plt.show()
# summarize history for loss
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('model loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train', 'test'], loc='upper left')
plt.show()

【问题讨论】:

【参考方案1】:

根据Wikipedia

In k-fold cross-validation, the original sample is randomly partitioned into k equal sized subsamples. Of the k subsamples, a single subsample is retained as the validation data for testing the model, and the remaining k − 1 subsamples are used as training data.

例如,像你想做 10 折交叉验证,你会

    将数据随机分成十个相等的部分

    用九个数据子集训练十个独立模型,

    验证每个模型的不同数据子集 分别找出每个模型的混淆矩阵

训练和测试的代码将保持不变,只是可以通过自定义生成器或使用 SKLearn 的KFold 获取数据输入。

【讨论】:

以上是关于为 CNN 模型实现交叉验证的主要内容,如果未能解决你的问题,请参考以下文章

字符识别--模型集成

字符识别--模型集成

交叉验证的实现

为啥交叉验证 RF 分类的性能比没有交叉验证的差?

5倍交叉验证如何理解

正则化交叉验证泛化能力