具有 F1 分数的 Keras 多标签图像分类
Posted
技术标签:
【中文标题】具有 F1 分数的 Keras 多标签图像分类【英文标题】:Keras multi-label image classification with F1-score 【发布时间】:2019-10-12 10:23:19 【问题描述】:我正在研究一个multi-label
图像分类问题,并根据系统预测标签和地面实况标签之间的F1-score
进行评估。
鉴于此,我应该使用loss="binary_crossentropy"
还是loss=keras_metrics.f1_score()
,其中keras_metrics.f1_score()
取自这里:https://pypi.org/project/keras-metrics/
?我有点困惑,因为我在网上找到的所有关于multi-label
分类的教程都是基于binary_crossentropy
损失函数,但这里我必须针对F1-score
进行优化。
此外,我应该设置metrics=["accuracy"]
还是metrics=[keras_metrics.f1_score()]
,或者我应该将其完全留空?
【问题讨论】:
你应该使用f1_score
作为度量值,而不是损失函数。 损失函数,用于模型的训练以指导优化过程,和我们用来理解的(人类可解释的)metrics之间是有区别的模型的性能(即准确性)。更重要的一点是,损失函数通常应该是可微的,而大多数使用的度量函数都不是这种情况。 This answer 也可能会有所帮助。
所以,只是为了确认一下:我可以使用来自keras_metrics
包的f1_score
作为metrics=
(人类interpretable
),但对于loss=
,我应该使用differentiable
函数。那么,F1-score
的以下differentiable
版本呢? https://www.kaggle.com/rejpalcz/best-loss-function-for-f1-score-metric
我现在无法验证该损失函数,但如果它是 1) 可微分和 2) 最小化它意味着模型具有更高的准确度(即更高的度量值),那么一切都很好,你可以用它。另一点:确实,在 Keras 中使用度量 'accuracy'
来解决 多标签 分类问题可能会给你一个错误的信号,尤其是当唯一标签的数量很高时,因为它给出的值非常高值,因此您可能认为您的模型非常做得很好,但实际上可能并非如此。这就是为什么f1_score
是一个更好的指标。
"所以我猜 'accuracy' 不是在这里使用的正确 loss 函数,所以我绝对应该使用 f1_score。": 是的,但我猜你的意思是正确的度量函数;)
这几乎是正确的,唯一的例外是:“我必须找到/实现'1 - F1-score'的可微版本并将其用作损失函数” .实际上,这样做没有必要性,因为binary_crossentropy
可能也有效。但是,如果您能找到一个以某种方式直接以 f1-score 为目标的损失函数,并且在您的实验中它的性能优于 binary_crossentropy
,那么一切都很好,您可以使用它来代替。但我建议您首先开始尝试最容易获得的选项:binary_crossentropy
。
【参考方案1】:
基于user706838的回答...
使用https://www.kaggle.com/rejpalcz/best-loss-function-for-f1-score-metric中的f1_score
import tensorflow as tf
import keras.backend as K
def f1_loss(y_true, y_pred):
tp = K.sum(K.cast(y_true*y_pred, 'float'), axis=0)
tn = K.sum(K.cast((1-y_true)*(1-y_pred), 'float'), axis=0)
fp = K.sum(K.cast((1-y_true)*y_pred, 'float'), axis=0)
fn = K.sum(K.cast(y_true*(1-y_pred), 'float'), axis=0)
p = tp / (tp + fp + K.epsilon())
r = tp / (tp + fn + K.epsilon())
f1 = 2*p*r / (p+r+K.epsilon())
f1 = tf.where(tf.is_nan(f1), tf.zeros_like(f1), f1)
return 1 - K.mean(f1)
【讨论】:
以上是关于具有 F1 分数的 Keras 多标签图像分类的主要内容,如果未能解决你的问题,请参考以下文章