修改后的 PyTorch 损失函数 BCEWithLogitsLoss 返回 NaNs
Posted
技术标签:
【中文标题】修改后的 PyTorch 损失函数 BCEWithLogitsLoss 返回 NaNs【英文标题】:Modified PyTorch loss function BCEWithLogitsLoss returns NaNs 【发布时间】:2020-10-20 10:49:50 【问题描述】:我正在尝试解决二进制分类问题(target=0
和 target=1
),但有一个例外:
我的一些标签被故意归类为 target=0.5
,我希望 零损失 将其分类为 0 或 1(即两个类都是“正确的”)。
我尝试根据 PyTorch 的 BCEWithLogitsLoss 从头开始实现自定义损失:
class myLoss(torch.nn.Module):
def __init__(self, pos_weight=1):
super().__init__()
self.pos_weight = pos_weight
def forward(self, input, target):
epsilon = 10 ** -44
my_bce_loss = -1 * (self.pos_weight * target * F.logsigmoid(input + epsilon)
+ (1 - target) * log(1 - sigmoid(input) + epsilon))
add_loss = (target - 0.5) ** 2 * 4
mean_loss = (my_bce_loss * add_loss).mean()
return mean_loss
epsilon
被选中,因此日志将限制为 -100,正如 BCE loss 中所建议的那样。
但是我仍然遇到 NaN 错误,经过几个时期:
Function 'LogBackward' returned nan values in its 0th output.
或
Function 'SigmoidBackward' returned nan values in its 0th output.
有什么建议可以纠正我的损失函数吗?也许通过某种方式继承和修改forward
函数?
更新: 我调用自定义损失函数的方式:
y = batch[:, -1, :].to(self.device, dtype=torch.float32)
y_pred_batch = self.model(x)
LossFun = myLoss(self.pos_weight)
batch_result.loss = LossFun.forward(y_pred_batch, y)
我用Temporal Convolutional Network model,实现如下:
out = self.conv1(x)
out = self.chomp1(out)
out = self.elu(out)
out = self.dropout1(out)
res = x if self.downsample is None else self.downsample(x)
return self.tanh(out + res)
【问题讨论】:
【参考方案1】:试试这个方法:
class myLoss(torch.nn.Module):
def __init__(self, pos_weight=1):
super().__init__()
self.pos_weight = pos_weight
def forward(self, input, target):
epsilon = 10 ** -44
input = input.sigmoid().clamp(epsilon, 1 - epsilon)
my_bce_loss = -1 * (self.pos_weight * target * torch.log(input)
+ (1 - target) * torch.log(1 - input))
add_loss = (target - 0.5) ** 2 * 4
mean_loss = (my_bce_loss * add_loss).mean()
return mean_loss
为了测试我向后执行 1000:
target = torch.randint(high=2, size=(32,))
loss_fn = myLoss()
for i in range(1000):
inp = torch.rand(1, 32, requires_grad=True)
loss = loss_fn(inp, target)
loss.backward()
if torch.isnan(loss):
print('Loss NaN')
if torch.isnan(inp.grad).any():
print('NaN')
一切正常。
【讨论】:
我试试看! 更新:文件“.../training.py”,第 511 行,向前 + (1 - 目标) * torch.log(1 - 输入)) 函数“LogBackward”返回 nan 值它的第 0 个输出。可悲的是,错误仍然存在 能否提供调用损失函数的代码?以上是关于修改后的 PyTorch 损失函数 BCEWithLogitsLoss 返回 NaNs的主要内容,如果未能解决你的问题,请参考以下文章