在 CrossEntropyLoss 和 BCELoss (PyTorch) 中使用权重

Posted

技术标签:

【中文标题】在 CrossEntropyLoss 和 BCELoss (PyTorch) 中使用权重【英文标题】:Using weights in CrossEntropyLoss and BCELoss (PyTorch) 【发布时间】:2021-08-16 04:44:25 【问题描述】:

我正在训练一个 PyTorch 模型来执行二进制分类。我的少数类约占数据的 10%,所以我想使用加权损失函数。 BCELossCrossEntropyLoss 的文档说我可以为每个样本使用 'weight'

但是,当我声明CE_loss = nn.BCELoss()nn.CrossEntropyLoss() 然后执行CE_Loss(output, target, weight=batch_weights),其中outputtargetbatch_weightsTensors 的batch_size,我收到以下错误留言:

forward() got an unexpected keyword argument 'weight'

【问题讨论】:

【参考方案1】:

您是否想对数据集中第 0 类和第 1 类的所有元素应用单独的固定权重?目前尚不清楚您在这里为 batch_weights 传递了什么值。如果是这样,那么这不是 BCELoss 中的权重参数所做的。 weight 参数要求您为数据集中的每个 ELEMENT 传递一个单独的权重,而不是为每个 CLASS 传递一个单独的权重。有几种方法可以解决这个问题。您可以为每个元素构建一个权重表。或者,您可以使用自定义损失函数来满足您的需求:

def BCELoss_class_weighted(weights):

    def loss(input, target):
        input = torch.clamp(input,min=1e-7,max=1-1e-7)
        bce = - weights[1] * target * torch.log(input) - (1 - target) * weights[0] * torch.log(1 - input)
        return torch.mean(bce)

  return loss

请注意,添加钳位以避免数值不稳定很重要。

HTH 杰伦

【讨论】:

【参考方案2】:

实现目标的另一种方法是在初始化损失时使用reduction=none,然后在计算平均值之前将结果张量乘以权重。 例如

loss = torch.nn.BCELoss(reduction='none')
model = torch.sigmoid

weights = torch.rand(10,1)
inputs = torch.rand(10,1)
targets = torch.rand(10,1)

intermediate_losses = loss(model(inputs), targets)
final_loss = torch.mean(weights*intermediate_losses)

当然,对于您的场景,您仍然需要计算权重张量。但希望这会有所帮助!

【讨论】:

【参考方案3】:

问题在于您提供了 weight 参数。正如文档here 中提到的,应该在模块实例化期间提供 weights 参数。

例如,类似,

from torch import nn
weights = torch.FloatTensor([2.0, 1.2]) 
loss = nn.BCELoss(weights=weights)

您可以找到更具体的示例 here 或其他有用的 PT 论坛讨论 here。

【讨论】:

BCELoss 的文档说“权重”应该是“手动重新调整权重,赋予每个批次元素的损失”。如果给定,则必须是大小为 nbatch 的张量。如果每批的权重都发生变化怎么办? 从表面上看,我认为这是不可能的。最重要的是,我认为每批动态调整权重可能会对学习产生负面影响,因为损失函数属性每批都在不断变化。【参考方案4】:

你需要像下面这样传递权重:

CE_loss = CrossEntropyLoss(weight=[…])

【讨论】:

以上是关于在 CrossEntropyLoss 和 BCELoss (PyTorch) 中使用权重的主要内容,如果未能解决你的问题,请参考以下文章

Pytorch详解NLLLoss和CrossEntropyLoss

fastjson-BCEL不出网打法原理分析

1. CrossEntropyLoss() 中的加权损失 2. WeightedRandomSampler 和 subsampler 的组合

pytorch nn.CrossEntropyLoss() 中的交叉熵损失

pytorch 自定义损失函数 nn.CrossEntropyLoss

pytorch 错误:CrossEntropyLoss() 中不支持多目标