python 实现 focal loss
Posted 一泓喜悲vv
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了python 实现 focal loss相关的知识,希望对你有一定的参考价值。
cross entropy的缺点
cross entropy的表达式:
log(x) 与 -log(x) 的曲线图:
cross entropy 的两个缺点:
1. 数量多的类别会主导损失函数和梯度下降,导致模型更有信心预测数量多的类别,而缺少对数量少类别的重视。Balance cross entropy可以解决。
2. 模型无法分辨困难样本和简单样本。困难样本是指模型反复出错的样本,简单样本是容易被分类的样本。cross entropy无法关注到困难样本。
balance cross entropy
Balanced Cross-Entropy loss 为每个类增加了一个权重因子alpha ,范围[0, 1]。 Alpha 可以是逆类别频率或由交叉验证确定的超参数。 alpha 参数替换交叉熵方程中的实际标签项。
尽管解决了类别不平衡的问题,但仍然无法区分困难和简单的样本。 这个问题通过 focal loss 解决了。
focal loss
Focal loss 关注模型出错的样本,而不是它可以自信预测的样本,确保对困难样本的预测随着时间的推移而改进,而不是对简单的样本变得过于自信。
Focal loss 通过一种叫做 Down Weighting 的东西来实现这一点。 Down weighting 是一种减少简单示例对损失函数的影响的技术,从而将更多注意力放在困难示例上。 该技术可以通过向交叉熵损失添加调制因子来实现。
其中 γ (Gamma) 是要使用交叉验证调整的聚焦参数。 下图显示了 Focal Loss 对于不同的 γ 值的表现。
伽马参数如何工作?
1. 在错误分类样本的情况下,pi 很小,使得调制因子近似或非常接近 1。这使损失函数不受影响。 因此,它表现为交叉熵损失。
2. 随着模型置信度的增加,即 pi → 1,调制因子将趋于 0,从而降低分类良好示例的损失值的权重。 聚焦参数 γ ≥ 1 将重新调整调制因子,使简单示例比难示例权重降低更多,从而减少它们对损失函数的影响。 例如,假设预测概率为 0.9 和 0.6。 考虑到 γ = 2,为 0.9 计算的损失值为 4.5e-4 并按 100 倍加权,0.6 为 3.5e-2 并按 6.25 倍加权。 从实验来看,γ = 2 对 Focal Loss 论文的作者来说效果最好。
3. 当 γ = 0 时,Focal Loss 相当于 Cross Entropy。
在实践中,我们使用 α 平衡的 focal loss 变体,它继承了权重因子 α 和聚焦参数 γ 的特性,比非平衡形式的精度略好。
Focal Loss 自然地解决了类别不平衡的问题,因为大多数类别的示例通常很容易预测,而少数类别的示例由于缺乏主导损失和梯度过程的数据或多数类别的示例而难以预测。 由于这种相似性,Focal Loss 可能能够解决这两个问题。
python实现
def py_sigmoid_focal_loss(pred, target, weight=None, gamma=2.0, alpha=0.25, reduction=\'mean\', avg_factor=None): pred_sigmoid = pred.sigmoid() target = target.type_as(pred) pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target) focal_weight = (alpha * target + (1 - alpha) * (1 - target)) * pt.pow(gamma) loss = F.binary_cross_entropy_with_logits(pred, target, reduction=\'none\') * focal_weight loss = weight_reduce_loss(loss, weight, reduction, avg_factor) return loss
参考链接:https://towardsdatascience.com/focal-loss-a-better-alternative-for-cross-entropy-1d073d92d075
以上是关于python 实现 focal loss的主要内容,如果未能解决你的问题,请参考以下文章
Focal Loss 安装与使用 TensorFlow2.x版本
Focal Loss 安装与使用 TensorFlow2.x版本
Focal Loss 安装与使用 TensorFlow2.x版本