如何计算或绘制 CNN 中每个类的误差?

Posted

技术标签:

【中文标题】如何计算或绘制 CNN 中每个类的误差?【英文标题】:How can i calculate or plot error of each class in CNN? 【发布时间】:2019-11-16 15:30:39 【问题描述】:

我将此代码用于我的项目,但我想计算或绘制每个类的错误。我有 6 个类。我该怎么办?

def plot_history(net_history):
    history = network_history.history
    losses = history['loss']
    accuracies = history['acc']
    plt.xlabel('Epochs')
    plt.ylabel('loss')
    plt.plot(losses)

    plt.figure()
    plt.xlabel('Epochs')
    plt.ylabel('accuracy')
    plt.plot(accuracies)

创建我的模型

myinput = layers.Input(shape=(100,200))
conv1 = layers.Conv1D(16, 3, activation='relu', padding='same', strides=2)(myinput)
conv2 = layers.Conv1D(32, 3, activation='relu', padding='same', strides=2)(conv1)
flat = layers.Flatten()(conv2)
out_layer = layers.Dense(6, activation='softmax')(flat)

mymodel = Model(myinput, out_layer)
mymodel.summary()
mymodel.compile(optimizer=keras.optimizers.Adam(), 
loss=keras.losses.categorical_crossentropy, metrics=['accuracy'])

训练我的模型

network_history = mymodel.fit(X_train, Y_train, batch_size=128,epochs=5, validation_split=0.2)
plot_history(network_history)

评估

test_loss, test_acc = mymodel.evaluate(X_test, Y_test)

test_labels_p = mymodel.predict(X_test)

【问题讨论】:

【参考方案1】:

评估分类器的一种简单方法是 scikit-learn 中的classification_report

from sklearn.metrics import classification_report

....

# Actual predictions here, not just probabilities
pred = numpy.round(mymodel.predict(X_test))
print(classification_report(Y_test, pred))

其中Y_test 是一个单热向量列表。

这将显示每个类别的准确率、召回率和 f1 度量。缺点是它只考虑预测的正确与否,没有考虑模型的确定性。

【讨论】:

嗨,谢谢。但是当我使用这个代码时,我有这个问题:精确recallf1-score支持0 0.00 0.00 0.00 23 1 0.00 0.00 0.00 0 0.00 0.00 0.23 3 0.00 0.00 0.00 0.00 0.00 0.23 5 0.00 0.00 0.23 micro avg 0.00 0.00 0.00 138 宏平均 0.00 0.00 0.00 138 加权平均 0.00 0.00 0.00 138 个样本平均 0.00 0.00 0.00 138 这可能表明模型对任何类的输出概率都不超过 50%,在这种情况下,我在这里使用的简单方法根本不会预测任何类成员资格。这可能表明您的模型不够灵活,无法进行分类,或者您没有训练足够长的时间——5 个 epoch 并不多,而且测试集的小尺寸表明训练集也可能非常小(这也可能导致问题,具体取决于任务的复杂性)。你的输入数据是什么?【参考方案2】:

你必须像二进制分类问题一样训练它,然后你可以使用这段代码为不同的类制作学习曲线:

plt.plot(network_history.history['loss'])
plt.plot(network_history.history['val_loss'])
plt.title('model loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train', 'test'], loc='upper left')
plt.show()

【讨论】:

以上是关于如何计算或绘制 CNN 中每个类的误差?的主要内容,如果未能解决你的问题,请参考以下文章

三种工具绘制errorbar图

seaborn库中柱状图绘制详解

如何使用 pyplot.bar 仅绘制正误差条?

如何基于cnn检测物体并绘制边界框

如何使用误差线和 p 值绘制每天的治疗平均值?

如何使用 python 绘制整个音频文件的频谱或频率与幅度的关系?