从 keras 生成器获取真实标签

Posted

技术标签:

【中文标题】从 keras 生成器获取真实标签【英文标题】:get true labels from keras generator 【发布时间】:2019-05-25 18:36:02 【问题描述】:

我想使用sklearn.metrics.confusion_matrix(y_true, y_pred) 为 keras 模型创建混淆矩阵。

训练模型后,我可以使用predict_generator(generator) 来获得测试数据集的预测,这给了我y_pred。如何从数据生成器中获取对应的真实标签y_true

【问题讨论】:

【参考方案1】:

generator.classes 将为您提供稀疏格式的观察值。您可能需要密集的(即一次性编码格式)。您可以通过以下方式获得:

import pandas as pd
pd.get_dummies(pd.Series(generator.classes)).to_dense()

注意:在生成预测和获取观察到的类之前,您必须将生成器的 shuffle 属性设置为 False,否则您的预测和观察将无法对齐!

【讨论】:

【参考方案2】:

创建数据生成器后,无论是您自己的还是内置的ImageDataGenerator,使用您经过训练的模型进行预测:

true_labels = data_generator.classes
predictions = model.predict_generator(data_generator)

sklearn 的混淆矩阵需要一维标签数组,因此您必须使用 np.argmax() 转换您的预测

y_true = true_labels
y_pred = np.array([np.argmax(x) for x in predictions])

然后你可以直接在confusion_matrix函数中使用这些变量

cm = sklearn.metrics.confusion_matrix(y_true, y_pred)

您可以使用此处的示例 plot_confusion_matrix() 函数对其进行绘制:

https://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html

【讨论】:

以上是关于从 keras 生成器获取真实标签的主要内容,如果未能解决你的问题,请参考以下文章

如何从生成器中获取字典输出,该生成器输出带有用于自定义 keras 图像生成器的字典的数组

与 Keras 的内置生成器相比,自定义 Keras 生成器要慢得多

Keras深度学习实战(22)——生成对抗网络详解与实现

我可以在 Keras 中使用 ImageDataGenerator() 和 flow_from_directory() 生成 uint8 标签吗?

数据增强图像数据生成器 Keras 语义分割

Keras搭建ACGAN生成MNIST手写体图片