如何找出概率输出中每列的哪个类对应于使用Keras进行多类分类?

Posted

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了如何找出概率输出中每列的哪个类对应于使用Keras进行多类分类?相关的知识,希望对你有一定的参考价值。

我正在使用转移学习在Keras中使用预先训练的VGG网络构建图像识别模型,并排除最终的完全连接层以获得输出权重。然后我使用这些输出权重进入我的新模型,该模型具有几个层以及我正在训练的新的全连接层。完全连接的层映射到我试图预测的输出类的数量。

一切都很好。但是,当我运行results = model.predict(img_tensor)时,我得到对应于每个类的输出概率,类似于下面的:

print(results)

 [[0.1426621  0.6193871  0.23795079]
 [0.11187755 0.6208466  0.2672758 ]
 [0.10050113 0.3768951  0.52260375]
 [0.1338948  0.59470254 0.27140263]
 [0.06612041 0.69726    0.2366195 ]
 [0.12080433 0.495977   0.38321865]]

我的问题是:如何找出概率输出中每列的哪个类对应?

Keras是否内置任何内容来确定输出概率的哪一列与哪个类相对应?如果没有提供,我会感到震惊......

其他人做了什么来创造一个解决方案?

谢谢!

答案

Keras的flow_from_directory(目录)中的“class_indices”属性在输出数组中创建类及其索引的字典:

classes:类子目录的可选列表(例如['dogs','cats'])。默认值:无。如果未提供,则将从目录下的子目录名称/结构自动推断类列表,其中每个子目录将被视为不同的类(并且将映射到标签索引的类的顺序将是字母数字)。包含从类名到类索引的映射的字典可以通过属性class_indices获得。

资料来源:https://keras.io/preprocessing/image/

我很想知道人们如何将这个结合到他们的工作流程/脚本中...

以上是关于如何找出概率输出中每列的哪个类对应于使用Keras进行多类分类?的主要内容,如果未能解决你的问题,请参考以下文章

获取二维数组中每列的第二个最小值

如何控制 flexbox 中每列的项目数?

计算Spark DataFrame中每列的内核密度

如何列出各个列,其中每个列包含一个 id 计数,其中每列中的 id 不在 MySQL 中每列的不同表中

如何获得每列的最大值?

计算循环中每列的中位数