图像分割损失函数OhemCELoss

Posted 超级无敌陈大佬的跟班

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了图像分割损失函数OhemCELoss相关的知识,希望对你有一定的参考价值。

OhemCELoss函数简介

OhemCELoss函数( Online hard example mining cross-entropy loss 的缩写)

分割任务中的OhemCELoss函数:其实就是分类任务的交叉熵函数—>每个像素计算分类交叉熵---->根据loss选取难样本,一步一步扩展得到。

在语义分割网络中常用的损失函数,这里大概记录几个需要留意的点:

1)计算交叉熵损失时,是以一个像素点为计算单位,计算出每个像素点的交叉熵分类损失。

2)ohem难样本挖掘时,根据给定的阈值选取前n_min个像素点的loss值。

源码:

使用pytorch框架OhemCELoss函数的代码实现

class OhemCELoss(nn.Module):
    """
    Online hard example mining cross-entropy loss:在线难样本挖掘
    if loss[self.n_min] > self.thresh: 最少考虑 n_min 个损失最大的 pixel,
    如果前 n_min 个损失中最小的那个的损失仍然大于设定的阈值,
    那么取实际所有大于该阈值的元素计算损失:loss=loss[loss>thresh]。
    否则,计算前 n_min 个损失:loss = loss[:self.n_min]
    """
    def __init__(self, thresh, n_min, ignore_lb=255, *args, **kwargs):
        super(OhemCELoss, self).__init__()
        self.thresh = -torch.log(torch.tensor(thresh, dtype=torch.float)).cuda()     # 将输入的概率 转换为loss值
        self.n_min = n_min
        self.ignore_lb = ignore_lb
        self.criteria = nn.CrossEntropyLoss(ignore_index=ignore_lb, reduction='none')   #交叉熵

    def forward(self, logits, labels):
        N, C, H, W = logits.size()
        loss = self.criteria(logits, labels).view(-1)
        loss, _ = torch.sort(loss, descending=True)     # 排序
        if loss[self.n_min] > self.thresh:       # 当loss大于阈值(由输入概率转换成loss阈值)的像素数量比n_min多时,取所以大于阈值的loss值
            loss = loss[loss>self.thresh]
        else:
            loss = loss[:self.n_min]
        return torch.mean(loss)

一篇讲述很详细的博客

详细的内容直接参见下面的博客就好,不必重复码字造车。

图像分割 OhemCELoss - 简书https://www.jianshu.com/p/24376b18e5c7

以上是关于图像分割损失函数OhemCELoss的主要内容,如果未能解决你的问题,请参考以下文章

图像分割损失函数OhemCELoss

图像分割 - Keras 中的自定义损失函数

语义分割损失函数

在训练 CNN 进行图像分割时,我的损失怎么会突然增加?

Pytorch 语义分割损失函数

PyTorch使用交叉熵作为语义分割损失函数遇到的坑