在 pytorch 中使用多个损失函数
Posted
技术标签:
【中文标题】在 pytorch 中使用多个损失函数【英文标题】:Using multiple loss functions in pytorch 【发布时间】:2021-02-20 14:57:12 【问题描述】:我正在处理图像恢复任务,我考虑了多个损失函数。我的计划是考虑 3 条路线:
1:使用多个损失进行监控,但仅使用少数损失进行自身训练 2:在用于训练的那些损失函数中,我需要给每个函数一个权重——目前我正在指定权重。我想让那个参数自适应。 3:如果在训练之间 - 如果我观察到饱和,我想改变损失函数。或其组件。目前我考虑重新训练一个网络(如果在第一次训练中模型饱和),使其在第一个 M 时期使用特定的损失函数进行训练,然后我改变损失。
除了最后一种情况,我开发了一个计算这些损失的代码,但我不确定它是否会起作用。 - 即它是否会反向传播? (代码如下)
在使用损失函数组合时是否可以自适应地给出权重 - 即我们可以训练网络以便也学习这些权重吗?
这个实现可以用于上面提到的改变损失函数的案例 3
抱歉,如果这里给出的任何内容不明确或错误。如果我必须改进这个问题,请告诉我。 (我对 PyTorch 有点陌生)
criterion = _criterion
#--training
prediction = model(input)
loss = criterion(prediction, target)
loss.backward()
class _criterion(nn.Module):
def __init__(self, model_type="CNN"):
super(_criterion).__init__()
self.model_type = model_type
def forward(self, pred, ref):
loss_1 = lambda x,y : nn.MSELoss(size_average=False)(x,y)
loss_2 = lambda x,y : nn.L1Loss(size_average=False)(x,y)
loss_3 = lambda x,y : nn.SmoothL1Loss(size_average=False)(x,y)
loss_4 = lambda x,y : L1_Charbonnier_loss_()(x,y) #user-defined
if opt.loss_function_order == 1:
loss_function_1 = get_loss_function(opt.loss_function_1)
loss = lambda x,y: 1*loss_function_1(x,y)
elif opt.loss_function_order == 2:
loss_function_1 = get_loss_function(opt.loss_function_1)
loss_function_2 = get_loss_function(opt.loss_function_2)
weight_1 = opt.loss_function_1_weight
weight_2 = opt.loss_function_2_weight
loss = lambda x,y: weight_1*loss_function_1(x,y) + weight_2*loss_function_2(x,y)
elif opt.loss_function_order == 3:
loss_function_1 = get_loss_function(opt.loss_function_1)
loss_function_2 = get_loss_function(opt.loss_function_2)
loss_function_3 = get_loss_function(opt.loss_function_3)
weight_1 = opt.loss_function_1_weight
weight_2 = opt.loss_function_2_weight
weight_3 = opt.loss_function_3_weight
loss = lambda x,y: weight_1*loss_function_1(x,y) + weight_2*loss_function_2(x,y) + weight_3*loss_function_3(x,y)
elif opt.loss_function_order == 4:
loss_function_1 = get_loss_function(opt.loss_function_1)
loss_function_2 = get_loss_function(opt.loss_function_2)
loss_function_3 = get_loss_function(opt.loss_function_3)
loss_function_4 = get_loss_function(opt.loss_function_4)
weight_1 = opt.loss_function_1_weight
weight_2 = opt.loss_function_2_weight
weight_3 = opt.loss_function_3_weight
weight_4 = opt.loss_function_4_weight
loss = lambda x,y: weight_1*loss_function_1(x,y) + weight_2*loss_function_2(x,y) + weight_3*loss_function_3(x,y) + weight_4*loss_function_4(x,y)
else:
raise Exception("_criterion : unable to interpret loss_function_order")
return loss(ref,pred), loss_1(ref,pred), loss_2(ref,pred), loss_3(ref,pred), loss_4(ref,pred)
def get_loss_function(loss):
if loss == "MSE":
criterion = nn.MSELoss(size_average=False)
elif loss == "MAE":
criterion = nn.L1Loss(size_average=False)
elif loss == "Smooth-L1":
criterion = nn.SmoothL1Loss(size_average=False)
elif loss == "Charbonnier":
criterion = L1_Charbonnier_loss_()
else:
raise Exception("not implemented")
return criterion
class L1_Charbonnier_loss_(nn.Module):
def __init__(self):
super(L1_Charbonnier_loss_, self).__init__()
self.eps = 1e-6
def forward(self, X, Y):
diff = torch.add(X, -Y)
error = self.eps*((torch.sqrt(1+((diff * diff)/self.eps)))-1)
loss = torch.sum(error)
return loss
【问题讨论】:
【参考方案1】:而我理解你的问题。您的误差计算函数将进行反向传播,但在使用误差函数时需要小心,因为它们在每种情况下的工作方式不同。 关于权重,您需要保存此网络的权重,然后使用 pytorch 的迁移学习将其再次加载到另一个网络中,以便您可以使用其他执行的权重。Here 是有关如何使用的链接从 pythorch 学习迁移。
【讨论】:
以上是关于在 pytorch 中使用多个损失函数的主要内容,如果未能解决你的问题,请参考以下文章