为啥我不能将交叉熵损失用于多标签?

Posted

技术标签:

【中文标题】为啥我不能将交叉熵损失用于多标签?【英文标题】:Why can't I use Cross Entropy Loss for multilabel?为什么我不能将交叉熵损失用于多标签? 【发布时间】:2021-01-16 04:13:00 【问题描述】:

我正在针对自然问题数据集中的长答案任务对 BERT 模型进行微调。我正在像 SQuAD 模型一样训练模型(预测开始和结束标记)。

我使用 Huggingface 和 PyTorch。

所以目标和标签的形状/大小为 [batch, 2]。我的问题是我无法输入“多目标”,我认为这是指最后一个形状是 2

RuntimeError:/pytorch/aten/src/THCUNN/generic/ClassNLLCriterion.cu:18 不支持多目标

我应该选择其他损失函数还是有其他方法可以绕过这个问题?

我正在使用的这段代码:

def loss_fn(preds, targets):
    return nn.CrossEntropyLoss()(preds,labels)
class DecoderModel(nn.Module):

    def __init__(self, model_args, encoder_config, loss_fn):
        super(DecoderModel, self).__init__()
        # ...

    def forward(self, pooled_output, labels):   
        pooled_output = self.dropout(pooled_output)
        logits = self.linear(pooled_output)

        start_logits, end_logits = logits.split(1, dim = -1)
        start_logit = torch.squeeze(start_logits, axis=-1)
        end_logit = torch.squeeze(end_logits, axis=-1)

        # Concatenate into a "label"
        preds = torch.cat((start_logits, end_logits), -1)

        # Calculate loss
        loss = self.loss_fn(
            preds = preds, 
            labels = labels)

        return loss, preds

目标属性是: torch.int64 & [3,2]

预测属性是: torch.float32 & [3,2]

已解决 - 这是我的解决方案

def loss_fn(preds:list, labels):
    start_token_labels, end_token_labels = labels.split(1, dim = -1)
    start_token_labels = start_token_labels.squeeze(-1)
    end_token_labels = end_token_labels.squeeze(-1)

    print('*'*50)
    print(preds[0].shape) # preds [0] and [1] has the same shape and dtype
    print(preds[0].dtype) # preds [0] and [1] has the same shape and dtype
    print(start_token_labels.shape) # labels [0] and [1] has the same shape and dtype
    print(start_token_labels.dtype) # labels [0] and [1] has the same shape and dtype

    start_loss = nn.CrossEntropyLoss()(preds[0], start_token_labels)
    end_loss = nn.CrossEntropyLoss()(preds[1], end_token_labels)

    avg_loss = (start_loss + end_loss) / 2
    return avg_loss

基本上,我正在拆分 logits(只是不连接它们)和标签。然后我对它们都进行交叉熵损失,最后取两者之间的平均损失。希望这能给您一个解决自己问题的想法!

【问题讨论】:

【参考方案1】:

你不应该给CrossEntropyLoss一个1-hot向量,而是直接给标签

Target: (N) 其中每个值是 0≤targets[i]≤C−1 ,或者 (N, d_1, d_2, ..., d_K) 在 K 维损失的情况下 K≥1。

您可以查看文档来重现您的错误:

>>> loss = nn.CrossEntropyLoss()
>>> input = torch.randn(3, 5, requires_grad=True)
>>> target = torch.empty(3, dtype=torch.long).random_(5)
>>> output = loss(input, target)
>>> output.backward()

但如果您将target 更改为target = torch.empty((3, 5), dtype=torch.long).random_(5),则会出现错误:

RuntimeError: 一维目标张量,不支持多目标

使用 nn.BCELoss 和 logits 作为输入,请参阅此示例:https://discuss.pytorch.org/t/multi-label-classification-in-pytorch/905/41

>>> nn.BCELoss()(torch.softmax(input, axis=1), torch.softmax(target.float(), axis=1))
>>> tensor(0.6376, grad_fn=<BinaryCrossEntropyBackward>)

【讨论】:

谢谢,这有帮助!我最终做了类似的事情,将 logits 分成 [3,1] 形状的碎片,然后对两个 logits 进行 CE 损失,然后取平均损失。

以上是关于为啥我不能将交叉熵损失用于多标签?的主要内容,如果未能解决你的问题,请参考以下文章

为啥训练多类语义分割的unet模型中的分类交叉熵损失函数非常高?

分类交叉熵和标签编码

为啥tf模型训练时的二元交叉熵损失与sklearn计算的不同?

多标签分类损失函数

多标签分类中的损失函数与评价指标

具有高度不平衡的多标签分类中的损失曲线