如何计算或绘制 CNN 中每个类的误差?
Posted
技术标签:
【中文标题】如何计算或绘制 CNN 中每个类的误差?【英文标题】:How can i calculate or plot error of each class in CNN? 【发布时间】:2019-11-16 15:30:39 【问题描述】:我将此代码用于我的项目,但我想计算或绘制每个类的错误。我有 6 个类。我该怎么办?
def plot_history(net_history):
history = network_history.history
losses = history['loss']
accuracies = history['acc']
plt.xlabel('Epochs')
plt.ylabel('loss')
plt.plot(losses)
plt.figure()
plt.xlabel('Epochs')
plt.ylabel('accuracy')
plt.plot(accuracies)
创建我的模型
myinput = layers.Input(shape=(100,200))
conv1 = layers.Conv1D(16, 3, activation='relu', padding='same', strides=2)(myinput)
conv2 = layers.Conv1D(32, 3, activation='relu', padding='same', strides=2)(conv1)
flat = layers.Flatten()(conv2)
out_layer = layers.Dense(6, activation='softmax')(flat)
mymodel = Model(myinput, out_layer)
mymodel.summary()
mymodel.compile(optimizer=keras.optimizers.Adam(),
loss=keras.losses.categorical_crossentropy, metrics=['accuracy'])
训练我的模型
network_history = mymodel.fit(X_train, Y_train, batch_size=128,epochs=5, validation_split=0.2)
plot_history(network_history)
评估
test_loss, test_acc = mymodel.evaluate(X_test, Y_test)
test_labels_p = mymodel.predict(X_test)
【问题讨论】:
【参考方案1】:评估分类器的一种简单方法是 scikit-learn 中的classification_report
:
from sklearn.metrics import classification_report
....
# Actual predictions here, not just probabilities
pred = numpy.round(mymodel.predict(X_test))
print(classification_report(Y_test, pred))
其中Y_test
是一个单热向量列表。
这将显示每个类别的准确率、召回率和 f1 度量。缺点是它只考虑预测的正确与否,没有考虑模型的确定性。
【讨论】:
嗨,谢谢。但是当我使用这个代码时,我有这个问题:精确recallf1-score支持0 0.00 0.00 0.00 23 1 0.00 0.00 0.00 0 0.00 0.00 0.23 3 0.00 0.00 0.00 0.00 0.00 0.23 5 0.00 0.00 0.23 micro avg 0.00 0.00 0.00 138 宏平均 0.00 0.00 0.00 138 加权平均 0.00 0.00 0.00 138 个样本平均 0.00 0.00 0.00 138 这可能表明模型对任何类的输出概率都不超过 50%,在这种情况下,我在这里使用的简单方法根本不会预测任何类成员资格。这可能表明您的模型不够灵活,无法进行分类,或者您没有训练足够长的时间——5 个 epoch 并不多,而且测试集的小尺寸表明训练集也可能非常小(这也可能导致问题,具体取决于任务的复杂性)。你的输入数据是什么?【参考方案2】:你必须像二进制分类问题一样训练它,然后你可以使用这段代码为不同的类制作学习曲线:
plt.plot(network_history.history['loss'])
plt.plot(network_history.history['val_loss'])
plt.title('model loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train', 'test'], loc='upper left')
plt.show()
【讨论】:
以上是关于如何计算或绘制 CNN 中每个类的误差?的主要内容,如果未能解决你的问题,请参考以下文章