Keras Multi-class 多标签图像分类:处理独立和从属标签和非二进制输出的混合
Posted
技术标签:
【中文标题】Keras Multi-class 多标签图像分类:处理独立和从属标签和非二进制输出的混合【英文标题】:Keras Multi-class Multi-label image classification: handle a mix of independent and dependent labels & non-binary output 【发布时间】:2021-04-01 18:07:31 【问题描述】:我正在尝试训练来自 Keras 的预训练 VGG16 模型,用于多类多标签分类任务。这些图像来自 NIH 的胸部 X 射线 8 数据集。该数据集有 14 个标签(14 种疾病)加上一个“未发现”标签。
我知道对于独立标签,比如14种疾病,我应该使用sigmoid激活+binary_crossentropy损失函数;对于依赖标签,我应该使用 softmax + categorical_crossentropy。
但是,在我总共 15 个标签中,其中 14 个是独立的,但“未发现”与其余 14 个在技术上是相关的 --> “未发现”和患有疾病的概率应该加起来为 1,但应独立给出患有何种疾病的概率。那么我应该使用什么损失呢?
此外,我的输出是一个浮点数(概率)列表,每一列都是一个标签。
y_true:
[[0. 0. 0. ... 0. 0. 1.]
[0. 0. 0. ... 0. 0. 1.]
[0. 0. 1. ... 0. 0. 0.]
...
[0. 0. 0. ... 0. 0. 1.]
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 1.]]
y_predict:
[[0.1749 0.0673 0.1046 ... 0. 0. 0.112 ]
[0. 0.1067 0.2804 ... 0. 0. 0.722 ]
[0. 0. 0.0686 ... 0. 0. 0.5373]
...
[0.0571 0.0679 0.0815 ... 0. 0. 0.532 ]
[0.0723 0.0555 0.2373 ... 0. 0. 0.4263]
[0.0506 0.1305 0.4399 ... 0. 0. 0.2792]]
这样的结果使得无法使用classification_report()
函数来评估我的模型。我正在考虑获得一个阈值以将其转换为二进制,但这将是更多的人工修改而不是 CNN 预测,因为我必须选择一个阈值。所以我不确定我是应该做一些硬编码的东西还是有其他已经存在的方法来处理这种情况?
我对 CNN 和分类很陌生,所以如果有人可以指导我或给我任何提示,我将非常感激。谢谢!
主体代码如下:
vgg16_model = VGG16()
last_layer = vgg16_model.get_layer('fc2').output
#I am treating them all as independent labels
out = Dense(15, activation='sigmoid', name='output_layer')(last_layer)
custom_vgg16_model = Model(inputs=vgg16_model.input, outputs=out)
for layer in custom_vgg16_model.layers[:-1]:
layer.trainable = False
custom_vgg16_model.compile(Adam(learning_rate=0.00001),
loss = "binary_crossentropy",
metrics = ['accuracy']) # metrics=accuracy gives me very good result,
# but I suppose it is due to the large amount
# of 0 label(not-this-disease prediction),
# therefore I am thinking to change it to
# recall and precision as metrics. If you have
# any suggestion on this I'd also like to hear!
【问题讨论】:
【参考方案1】:关于我的项目的一些更新,实际上我已经设法解决了这个问题中提到的大部分问题。
首先,由于这是一个多类多标签分类问题,我决定使用 ROC-AUC 分数而不是精确率或召回率作为评估指标。它的优点是不涉及阈值——AUC 有点像一系列阈值下的性能平均值。而且它只看正面预测,因此它减少了数据集中大多数 0 的影响。在我的案例中,这可以更准确地预测模型的性能。
对于输出类,我决定使用 14 个类而不是 15 个——如果所有标签都为 0,则表示“找不到”。然后我可以愉快地在我的输出层中使用 sigmoid 激活。尽管如此,我还是使用了焦点损失而不是二元交叉熵,因为我的数据集高度不平衡。
我仍然面临问题,因为我的 ROC 不好(非常接近 y=x,有时低于 y=x)。但我希望我的进步能给任何发现这一点的人一些启发。
【讨论】:
以上是关于Keras Multi-class 多标签图像分类:处理独立和从属标签和非二进制输出的混合的主要内容,如果未能解决你的问题,请参考以下文章