Focal Loss 之 PyTorch 实现
Posted 泯灭XzWz
tags:
篇首语：本文由小常识网（cha138.com）小编为大家整理，主要介绍了 Focal Loss 的 PyTorch 实现的相关知识，希望对你有一定的参考价值。
# coding=utf-8
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import CrossEntropyLoss
import numpy as np
class MultiFocalLoss(nn.Module):
    """
    Multi-class focal loss.

        Focal_Loss = -1 * alpha_t * ((1 - pt) ** gamma) * log(pt)

    where pt is the softmax probability of the target class and alpha_t is
    the class-balance weight of that target class.

    Args:
        num_class: number of classes C.
        alpha: class balance factor, one value per class.
            ``None``      -> 0.5 for every class;
            int / float   -> that value for every class;
            otherwise anything ``torch.as_tensor`` accepts (list, tuple,
            numpy array, tensor) with ``num_class`` entries.
            A 0-d value is broadcast to every class.
        gamma: focusing hyper-parameter. ``gamma == 0`` reduces to
            alpha-weighted cross-entropy.
        reduction: 'mean' | 'sum' | 'none'.

    Raises:
        RuntimeError: if ``alpha`` does not have exactly ``num_class`` entries.
        ValueError: if ``reduction`` is not one of 'mean', 'sum', 'none'.
    """

    _REDUCTIONS = ('mean', 'sum', 'none')

    def __init__(self, num_class, alpha=None, gamma=2, reduction='mean'):
        super().__init__()
        if reduction not in self._REDUCTIONS:
            # Fail early instead of silently returning an unreduced loss.
            raise ValueError(
                f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}"
            )
        self.num_class = num_class
        self.gamma = gamma
        self.reduction = reduction
        # Kept for backward compatibility with code that reads this attribute.
        # forward() works in log-space and no longer adds an epsilon.
        self.smooth = 1e-4

        if alpha is None:
            alpha = torch.full((num_class,), 0.5)
        elif isinstance(alpha, (int, float)):
            alpha = torch.full((num_class,), float(alpha))
        else:
            alpha = torch.as_tensor(alpha)
            if alpha.dim() == 0:
                # A scalar tensor / numpy scalar: same value for every class.
                alpha = alpha.reshape(1).repeat(num_class)
        if alpha.dim() != 1 or alpha.shape[0] != num_class:
            raise RuntimeError('the length not equal to number of class')
        self.alpha = alpha

    def forward(self, logit, target):
        """
        Compute the focal loss.

        N: batch size, C: number of classes.

        :param logit: unnormalized scores, [N, C] or [N, C, d1, d2, ...].
            The class dimension must be dim 1.
        :param target: integer class indices, [N] or [N, d1, d2, ...].
        :return: a scalar for 'mean' / 'sum'; for 'none', a tensor with the
            same shape as ``target``.
        :raises ValueError: if ``target`` does not hold exactly one class
            index per prediction.
        """
        num_class = logit.size(1)
        # log_softmax is numerically stable. No epsilon is needed inside
        # log(), and log(pt) does not saturate for confidently wrong
        # predictions.
        log_prob = F.log_softmax(logit, dim=1)
        if log_prob.dim() > 2:
            # [N, C, d1, d2, ...] -> [N, C, m] -> [N, m, C]  (m = d1*d2*...)
            log_prob = log_prob.reshape(log_prob.size(0), num_class, -1)
            log_prob = log_prob.transpose(1, 2)
        # -> [N*m, C]; reshape copes with the non-contiguous transpose.
        log_prob = log_prob.reshape(-1, num_class)

        ori_shp = target.shape
        # gather() requires int64 indices.
        target = target.reshape(-1).long()
        if target.numel() != log_prob.size(0):
            # gather() would otherwise silently use only a prefix of the rows.
            raise ValueError(
                f"target has {target.numel()} elements but logit provides "
                f"{log_prob.size(0)} predictions"
            )

        logpt = log_prob.gather(1, target.unsqueeze(1)).squeeze(1)
        # 1 - pt computed as -expm1(log pt) is accurate near pt == 1 and is
        # clamped non-negative, so a fractional gamma cannot yield NaN.
        focal_weight = (-torch.expm1(logpt)).clamp(min=0.0).pow(self.gamma)
        # Match logit's device and dtype so e.g. a float64 numpy-built alpha
        # does not silently promote the loss to float64.
        alpha = self.alpha.to(device=logit.device, dtype=logpt.dtype)
        loss = -alpha[target] * focal_weight * logpt

        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        return loss.view(ori_shp)
if __name__ == "__main__":
    # Smoke-test MultiFocalLoss on random data in both supported layouts.
    batch_size, seq_len, num_class = 1, 2, 3

    # 2-D case: logits (batch_size, num_class), targets (batch_size,)
    loss_fn = MultiFocalLoss(num_class=num_class, alpha=1, gamma=2, reduction='mean')
    logits = torch.rand(batch_size, num_class, requires_grad=True)  # (batch_size, num_class)
    targets = torch.randint(0, num_class, size=(batch_size,))  # (batch_size,)
    loss = loss_fn(logits, targets)
    print(loss)
    loss.backward()

    # Multi-dimensional case: one prediction per sequence position.
    loss_fn = MultiFocalLoss(num_class=num_class, gamma=2.0, reduction='mean')
    logits = torch.rand(batch_size, seq_len, num_class, requires_grad=True)  # (batch_size, seq_len, num_class)
    targets = torch.randint(0, num_class, size=(batch_size, seq_len))  # (batch_size, seq_len)
    # The class dimension must be dim 1, so permute to (batch_size, num_class, seq_len).
    loss = loss_fn(logits.permute(0, 2, 1), targets)
    print(loss)
    loss.backward()
以上是关于 Focal Loss 之 PyTorch 实现的主要内容，如果未能解决你的问题，请参考以下文章。