如何在pytorch中计算BCEWithLogitsLoss的不平衡权重
Posted
技术标签:
【中文标题】如何在pytorch中计算BCEWithLogitsLoss的不平衡权重【英文标题】:How to calculate unbalanced weights for BCEWithLogitsLoss in pytorch 【发布时间】:2019-11-23 02:04:30 【问题描述】:我正在尝试使用270
标签解决一个多标签问题,并且我已将目标标签转换为一种热编码形式。我正在使用BCEWithLogitsLoss()
。由于训练数据不平衡,我使用pos_weight
参数,但我有点困惑。
pos_weight
(Tensor, optional) – 正样本的权重。必须是长度等于类数的向量。
我是否需要将每个标签的正值总数作为张量或它们的权重表示其他含义?
【问题讨论】:
你可以在这里查看讨论:discuss.pytorch.org/t/… 【参考方案1】:PyTorch documentation for BCEWithLogitsLoss 建议将 pos_weight 设为每个类的负数和正数之间的比率。
因此,如果len(dataset)
为 1000,则 multihot 编码的元素 0 有 100 个正数,那么 pos_weights_vector
的元素 0 应该是 900/100 = 9
。这意味着二元交叉损失将表现为数据集包含 900 个正例而不是 100 个。
这是我的实现:
def calculate_pos_weights(class_counts):
pos_weights = np.ones_like(class_counts)
neg_counts = [len(data)-pos_count for pos_count in class_counts]
for cdx, pos_count, neg_count in enumerate(zip(class_counts, neg_counts)):
pos_weights[cdx] = neg_count / (pos_count + 1e-5)
return torch.as_tensor(pos_weights, dtype=torch.float)
其中class_counts
只是正样本的逐列总和。我在 PyTorch 论坛上 posted it 并且其中一位 PyTorch 开发人员给予了他的祝福。
【讨论】:
你能说清楚退货的重量吗?是 pos_weights 还是任何其他权重。如果是不同的东西,你能提供更多的细节吗? @ParadiN 抱歉,我的变量不清楚。我在回答中澄清了变量名称。 另一种方法:从数据标签获取class_counts(示例标签:[[0],[0],[1]...])classes, class_counts = np.unique(labels, return_counts=True)
。计算重量:weight = [(class_counts.sum() - x)/ (x+1e-5) for i, x in enumerate(class_counts)]
【参考方案2】:
PyTorch 解决方案
嗯,实际上我已经浏览了文档,您确实可以简单地使用 pos_weight
。
这个参数赋予每个类的正样本权重,因此如果你有 270
类,你应该传递 torch.Tensor
形状 (270,)
定义每个类的权重。
这里是从documentation略微修改的sn-p:
# 270 classes, batch size = 64
target = torch.ones([64, 270], dtype=torch.float32)
# Logits outputted from your network, no activation
output = torch.full([64, 270], 0.9)
# Weights, each being equal to one. You can input your own here.
pos_weight = torch.ones([270])
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
criterion(output, target) # -log(sigmoid(0.9))
自制解决方案
在权重方面,没有内置的解决方案,但您可以很容易地自己编写一个解决方案:
import torch
class WeightedMultilabel(torch.nn.Module):
def __init__(self, weights: torch.Tensor):
self.loss = torch.nn.BCEWithLogitsLoss()
self.weights = weights.unsqueeze()
def forward(outputs, targets):
return self.loss(outputs, targets) * self.weights
Tensor
的长度必须与多标签分类中的类别数 (270) 相同,每个类别都为您的具体示例赋予权重。
计算权重
您只需添加数据集中每个样本的标签,除以最小值并在最后取反。
sn-p 之类的:
weights = torch.zeros_like(dataset[0])
for element in dataset:
weights += element
weights = 1 / (weights / torch.min(weights))
使用这种方法,出现最少的类会产生正常损失,而其他类的权重会小于1
。
但它可能会在训练期间导致一些不稳定,因此您可能想稍微尝试一下这些值(也许log
变换而不是线性?)
其他方法
您可能会考虑上采样/下采样(尽管此操作很复杂,因为您还会添加/删除其他类,因此我认为需要高级启发式)。
【讨论】:
【参考方案3】:只是为了快速修改@crypdick 的答案,这个函数的实现对我有用:
def calculate_pos_weights(class_counts,data):
pos_weights = np.ones_like(class_counts)
neg_counts = [len(data)-pos_count for pos_count in class_counts]
for cdx, (pos_count, neg_count) in enumerate(zip(class_counts, neg_counts)):
pos_weights[cdx] = neg_count / (pos_count + 1e-5)
return torch.as_tensor(pos_weights, dtype=torch.float)
data
是您尝试对其应用权重的数据集。
【讨论】:
【参考方案4】:也许有点晚了,但我是这样计算的:
num_positives = torch.sum(train_dataset.data.y, dim=0)
num_negatives = len(train_dataset.data.y) - num_positives
pos_weight = num_negatives / num_positives
那么权重可以很容易地用作:
criterion = torch.nn.BCEWithLogitsLoss(pos_weight = pos_weight)
【讨论】:
以上是关于如何在pytorch中计算BCEWithLogitsLoss的不平衡权重的主要内容,如果未能解决你的问题,请参考以下文章
如何计算 PyTorch 中注意力分数和编码器输出的加权平均值?
当目标不是单热时,如何计算 Pytorch 中 2 个张量之间的正确交叉熵?
详解pytorch中的交叉熵损失函数nn.BCELoss()nn.BCELossWithLogits(),二分类任务如何定义损失函数,如何计算准确率如何预测