图像分割损失函数OhemCELoss
Posted 超级无敌陈大佬的跟班
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了图像分割损失函数OhemCELoss相关的知识,希望对你有一定的参考价值。
OhemCELoss函数简介
OhemCELoss函数( Online hard example mining cross-entropy loss 的缩写)
分割任务中的OhemCELoss函数:其实就是分类任务的交叉熵函数—>每个像素计算分类交叉熵---->根据loss选取难样本,一步一步扩展得到。
在语义分割网络中常用的损失函数,这里大概记录几个需要留意的点:
1)计算交叉熵损失时,是以一个像素点为计算单位,计算出每个像素点的交叉熵分类损失。
2)ohem难样本挖掘时,根据给定的阈值选取前n_min个像素点的loss值。
源码:
使用pytorch框架OhemCELoss函数的代码实现
class OhemCELoss(nn.Module):
"""
Online hard example mining cross-entropy loss:在线难样本挖掘
if loss[self.n_min] > self.thresh: 最少考虑 n_min 个损失最大的 pixel,
如果前 n_min 个损失中最小的那个的损失仍然大于设定的阈值,
那么取实际所有大于该阈值的元素计算损失:loss=loss[loss>thresh]。
否则,计算前 n_min 个损失:loss = loss[:self.n_min]
"""
def __init__(self, thresh, n_min, ignore_lb=255, *args, **kwargs):
super(OhemCELoss, self).__init__()
self.thresh = -torch.log(torch.tensor(thresh, dtype=torch.float)).cuda() # 将输入的概率 转换为loss值
self.n_min = n_min
self.ignore_lb = ignore_lb
self.criteria = nn.CrossEntropyLoss(ignore_index=ignore_lb, reduction='none') #交叉熵
def forward(self, logits, labels):
N, C, H, W = logits.size()
loss = self.criteria(logits, labels).view(-1)
loss, _ = torch.sort(loss, descending=True) # 排序
if loss[self.n_min] > self.thresh: # 当loss大于阈值(由输入概率转换成loss阈值)的像素数量比n_min多时,取所以大于阈值的loss值
loss = loss[loss>self.thresh]
else:
loss = loss[:self.n_min]
return torch.mean(loss)
一篇讲述很详细的博客
详细的内容直接参见下面的博客就好,不必重复码字造车。
图像分割 OhemCELoss - 简书https://www.jianshu.com/p/24376b18e5c7
以上是关于图像分割损失函数OhemCELoss的主要内容,如果未能解决你的问题,请参考以下文章