Focal Loss 之 PyTorch 实现

Posted by 泯灭XzWz

tags:

篇首语：本文由小常识网（cha138.com）小编为大家整理，主要介绍了 Focal Loss 之 PyTorch 实现的相关知识，希望对你有一定的参考价值。

# coding=utf-8
import torch
import torch.nn.functional as F

from torch import nn
from torch.nn import CrossEntropyLoss
import numpy as np

class MultiFocalLoss(nn.Module):
    """Multi-class focal loss.

    ``FL(p_t) = -alpha_t * (1 - p_t) ** gamma * log(p_t)``, where ``p_t`` is
    the softmax probability assigned to the target class and ``alpha_t`` is
    the weight of the target class.

    Args:
        num_class: number of classes ``C``.
        alpha: per-class balance weights. ``None`` gives 0.5 for every class;
            a scalar is repeated for every class; a list / tuple / ndarray /
            tensor must contain exactly ``num_class`` entries.
        gamma: focusing hyper-parameter (``gamma = 0`` reduces to weighted
            cross-entropy).
        reduction: ``'mean'``, ``'sum'`` or ``'none'``.

    Raises:
        RuntimeError: if ``alpha`` does not have ``num_class`` entries.
        ValueError: if ``reduction`` is not one of the supported values.
    """

    _REDUCTIONS = ('mean', 'sum', 'none')

    def __init__(self, num_class, alpha=None, gamma=2, reduction='mean'):
        super(MultiFocalLoss, self).__init__()
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}")
        self.num_class = num_class
        self.gamma = gamma
        self.reduction = reduction
        # Kept for backward compatibility only: forward() works in log space
        # via log_softmax, so no additive smoothing term is needed.
        self.smooth = 1e-4

        if alpha is None:
            alpha = 0.5
        if not isinstance(alpha, torch.Tensor):
            # Covers int/float, numpy scalars, lists, tuples and ndarrays;
            # float dtype avoids integer weights for e.g. alpha=1.
            alpha = torch.as_tensor(alpha, dtype=torch.float32)
        if alpha.dim() == 0:
            # A scalar weight applies equally to every class.
            alpha = alpha.repeat(num_class)
        if alpha.dim() != 1 or alpha.shape[0] != num_class:
            raise RuntimeError('the length not equal to number of class')
        self.alpha = alpha

    def forward(self, logit, target):
        """Compute the focal loss.

        N: batch size, C: number of classes.

        Args:
            logit: unnormalised scores, ``[N, C]`` or ``[N, C, d1, d2, ...]``
                (the class dimension must be dim 1).
            target: integer class indices, ``[N]`` or ``[N, d1, d2, ...]``.

        Returns:
            A scalar for ``'mean'`` / ``'sum'``; for ``'none'`` a tensor with
            the same shape as ``target``.
        """
        # log_softmax is numerically stable and never yields log(0), so very
        # confident / very wrong predictions neither NaN nor get clipped.
        log_prob = F.log_softmax(logit, dim=1)

        if log_prob.dim() > 2:
            # [N, C, d1, d2, ...] -> [N, C, m] -> [N, m, C] -> [N*m, C]
            # (reshape instead of view: permuted inputs are non-contiguous).
            n, c = log_prob.shape[:2]
            log_prob = log_prob.reshape(n, c, -1).transpose(1, 2).reshape(-1, c)

        ori_shp = target.shape
        # gather() and tensor indexing both require int64 indices.
        target = target.reshape(-1).long()

        logpt = log_prob.gather(1, target.unsqueeze(1)).squeeze(1)
        # logpt <= 0, so pt <= 1 and the pow base (1 - pt) is never negative.
        pt = logpt.exp()
        alpha = self.alpha.to(device=logit.device, dtype=logpt.dtype)
        alpha_weight = alpha[target]
        loss = -alpha_weight * torch.pow(1.0 - pt, self.gamma) * logpt

        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        return loss.reshape(ori_shp)

if __name__ == "__main__":
    # Smoke test: run MultiFocalLoss on a plain [N, C] input and on a
    # sequence-shaped input, printing each loss and back-propagating it.
    batch_size = 1
    seq_len = 2
    num_class = 3

    def _report(criterion, scores, labels):
        """Evaluate criterion(scores, labels), print it, then backprop."""
        value = criterion(scores, labels)
        print(value)
        value.backward()

    # Case 1 (2-D): logits (batch_size, num_class), targets (batch_size,).
    criterion_2d = MultiFocalLoss(num_class=num_class, alpha=1, gamma=2, reduction='mean')
    scores_2d = torch.rand(batch_size, num_class, requires_grad=True)
    labels_2d = torch.randint(0, num_class, size=(batch_size,))
    _report(criterion_2d, scores_2d, labels_2d)

    # Case 2 (multi-dim): logits (batch_size, seq_len, num_class), targets
    # (batch_size, seq_len). The class axis must be dim 1, so the logits are
    # permuted to (batch_size, num_class, seq_len) before the call.
    criterion_nd = MultiFocalLoss(num_class=num_class, gamma=2.0, reduction='mean')
    scores_nd = torch.rand(batch_size, seq_len, num_class, requires_grad=True)
    labels_nd = torch.randint(0, num_class, size=(batch_size, seq_len))
    _report(criterion_nd, scores_nd.permute(0, 2, 1), labels_nd)

以上是关于 Focal Loss 之 PyTorch 实现的主要内容。如果未能解决你的问题，请参考以下文章：

损失函数解读 之 Focal Loss

损失函数解读 之 Focal Loss

python 实现 focal loss

CTC Loss和Focal CTC Loss

Focal Loss

Keras 自定义loss函数 focal loss + triplet loss