如何在 PyTorch 中计算自举交叉熵损失?

Posted

技术标签:

【中文标题】如何在 PyTorch 中计算自举交叉熵损失?【英文标题】:How do I compute bootstrapped cross entropy loss in PyTorch? 【发布时间】:2020-12-23 09:03:28 【问题描述】:

我读过一些论文,它们使用“自举交叉熵损失”来训练他们的分割网络。我们的想法是只关注最难的 k%(比如 15%)像素,以提高学习性能,尤其是在容易的像素占主导地位时。

目前,我使用的是标准交叉熵:

loss = F.binary_cross_entropy(mask, gt)

如何在 PyTorch 中有效地将其转换为引导版本?

【问题讨论】:

【参考方案1】:

添加到@hkchengrex 的自我回答(用于将来的自我和 API 与 PyTorch 的对等);

可以像这样首先实现functional 版本(在original torch.nn.functional.cross_entropy 中提供一些额外的参数)(我也更喜欢reductioncallable 而不是预定义的字符串):

import typing

import torch


def bootstrapped_cross_entropy(
    inputs,
    targets,
    iteration,
    p: float,
    warmup: typing.Union[typing.Callable[[float, int], float], int] = -1,
    weight=None,
    ignore_index=-100,
    reduction: typing.Callable[[torch.Tensor], torch.Tensor] = torch.mean,
):
    if not 0 < p < 1:
        raise ValueError("p should be in [0, 1] range, got: ".format(p))

    if isinstance(warmup, int):
        this_p = 1.0 if iteration < warmup else p
    elif callable(warmup):
        this_p = warmup(p, iteration)
    else:
        raise ValueError(
            "warmup should be int or callable, got ".format(type(warmup))
        )

    # Shortcut
    if this_p == 1.0:
        return torch.nn.functional.cross_entropy(
            inputs, targets, weight, ignore_index=ignore_index, reduction=reduction
        )

    raw_loss = torch.nn.functional.cross_entropy(
        inputs, targets, weight=weight, ignore_index=ignore_index, reduction="none"
    ).view(-1)
    num_pixels = raw_loss.numel()

    loss, _ = torch.topk(raw_loss, int(num_pixels * this_p), sorted=False)
    return reduction(loss)

还可以将warmup 指定为callable(采用p 和当前iteration)或int,这允许灵活或轻松的调度。

并在每次调用期间自动递增 _WeightedLossiteration 的类(因此只有 inputstargets 必须通过):

class BoostrappedCrossEntropy(torch.nn.modules.loss._WeightedLoss):
    def __init__(
        self,
        p: float,
        warmup: typing.Union[typing.Callable[[float, int], float], int] = -1,
        weight=None,
        ignore_index=-100,
        reduction: typing.Callable[[torch.Tensor], torch.Tensor] = torch.mean,
    ):
        self.p = p
        self.warmup = warmup
        self.ignore_index = ignore_index
        self._current_iteration = -1

        super().__init__(weight, size_average=None, reduce=None, reduction=reduction)

    def forward(self, inputs, targets):
        self._current_iteration += 1
        return bootstrapped_cross_entropy(
            inputs,
            targets,
            self._current_iteration,
            self.p,
            self.warmup,
            self.weight,
            self.ignore_index,
            self.reduction,
        )

【讨论】:

【参考方案2】:

通常我们还会在损失中添加一个“热身”期,以便网络可以学习先适应容易的区域并过渡到较难的区域。

此实现从 k=100 开始并持续 20000 次迭代,然后线性衰减到 k=15 再进行 50000 次迭代。

class BootstrappedCE(nn.Module):
    def __init__(self, start_warm=20000, end_warm=70000, top_p=0.15):
        super().__init__()

        self.start_warm = start_warm
        self.end_warm = end_warm
        self.top_p = top_p

    def forward(self, input, target, it):
        if it < self.start_warm:
            return F.cross_entropy(input, target), 1.0

        raw_loss = F.cross_entropy(input, target, reduction='none').view(-1)
        num_pixels = raw_loss.numel()

        if it > self.end_warm:
            this_p = self.top_p
        else:
            this_p = self.top_p + (1-self.top_p)*((self.end_warm-it)/(self.end_warm-self.start_warm))
        loss, _ = torch.topk(raw_loss, int(num_pixels * this_p), sorted=False)
        return loss.mean(), this_p

【讨论】:

以上是关于如何在 PyTorch 中计算自举交叉熵损失?的主要内容,如果未能解决你的问题,请参考以下文章

pytorch 中的交叉熵损失如何工作?

如何根据 PyTorch 中的概率计算交叉熵?

pytorch nn.CrossEntropyLoss() 中的交叉熵损失

当目标不是单热时,如何计算 Pytorch 中 2 个张量之间的正确交叉熵?

Pytorch常用的交叉熵损失函数CrossEntropyLoss()详解

pytorch交叉熵损失函数 F.cross_entropy()