测试准确度为 98% 的模型的混淆矩阵不准确

Posted

技术标签:

【中文标题】测试准确度为 98% 的模型的混淆矩阵不准确【英文标题】:Inaccurate confusion matrix for a model with 98% test accuracy 【发布时间】:2019-01-11 14:58:29 【问题描述】:

我训练了一个二元分类模型,得到了 98% 的测试准确率和 99% 的训练准确率。

今天我想计算混淆矩阵并使用下面的代码来计算它们。

model = load_model("model.h5")

testGenerator = ImageDataGenerator(rotation_range=5,
                                width_shift_range=0.2,
                                height_shift_range=0.2,
                                horizontal_flip=False,
                                fill_mode='nearest'
                                )   

testData = testGenerator.flow_from_directory(
                                'Location', 
                                target_size=(74,448),                                                 
                                batch_size=15,
                                class_mode='binary',
                                shuffle=False
                                )

proba = model.predict_generator(testData,steps=3000//15)
y_true = np.array([0] * 1482 + [1] * 1482 )
y_pred = proba > 0.5
print(confusion_matrix(y_true, y_pred))

我收到了这个混淆矩阵:

正如 sklearn 所说:

这里说的假阴性和假阳性是如此之高。既然我有 98% 的测试准确率,这怎么可能呢?此外,我多次使用该模型生成预测(使用 model.predict() 函数)并手动检查它们。但每次它都给了我正确的分类。

任何想法如何获得准确的结果?

【问题讨论】:

你的真实数据真的像你的 y_true 变量一样分布吗? @CupinaCoffee 我已经设置了shuffle=false 来做到这一点。 好的。检查来自 soumendra 的评论 github.com/keras-team/keras/issues/3477 @CupinaCoffee 谢谢。我之前看过那篇文章,但他使用了 train_generator.class_indices,因为我已经训练了模型,所以我没有。 听起来您的初始模型可能在训练期间过度拟合。你能描述一下你训练模型的过程吗? 【参考方案1】:

让我们从头开始。消息“TypeError: unhashable type: 'numpy.ndarray'”表示您不能将numpy.ndarray 用作字典键,因为它不是不可变对象。首先将其转换为 tuple 或其他内容以使其不可变。

关于您的混淆矩阵,我敢打赌,生成器会以不可预知的顺序从文件夹中加载文件,但您仍然将 y_true 设置为 1482 zeros 和 1482 ones - 这可能与顺序匹配,也可能不匹配由生成器生成的文件。因此,您会得到有趣的结果。

【讨论】:

我最近修复了这个错误。但是来自 predict_generator() 的预测仍然不准确,因为 predict() 结果。另外,我设置了 shuffle=false。那它怎么会认为它们是不可预测的呢? @Sam94 好吧,那么不要使用predict_generator() 要生成混淆矩阵并进行统计分析,最好有大量样本,对吧? (至少 500 个)。但我只有 100 个这样的图像。所以我需要对它们进行扩充和预测。那么还有哪些其他可能的解决方案可供我使用? @Sam94 我不反对生成器,只是你的标签不适合数据集。根据您拥有的数据计算混淆矩阵。如果您获得明显更好的结果,请修正生成数据和基本事实的方式。 我使用了没有生成图像的图像并获得了准确的结果。我的意思是使用 predict() 函数。而不是使用 predict_generator()。所以我需要纠正这个问题。这就是我寻求帮助的原因

以上是关于测试准确度为 98% 的模型的混淆矩阵不准确的主要内容,如果未能解决你的问题,请参考以下文章

怎么计算混淆矩阵的消费者精度

如何使用混淆矩阵计算自定义训练的 spacy ner 模型的整体准确性?

29、评估多分类问题--混淆矩阵和F分数

分类模型评估:混淆矩阵准确率召回率ROC

混淆矩阵 - 测试情绪分析模型

R语言编写自定义函数计算分类模型评估指标:准确度特异度敏感度PPVNPV数据数据为模型预测后的混淆矩阵比较多个分类模型分类性能(逻辑回归决策树随机森林支持向量机)