深度学习中如何平衡多个loss?多任务学习自动调整loss weight解决方案

Posted 沉迷单车的追风少年

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了深度学习中如何平衡多个loss?多任务学习自动调整loss weight解决方案相关的知识,希望对你有一定的参考价值。

 目录

问题引入

手动调参

自动调参

从每个任务的同方差不确定性出发

论文赏析

代码分析

视频讲解

说明:写在最后


问题引入

深度学习中,针对loss权重的优化是重要的改进方向,许多深度学习应用都受益于具有多重回归和分类目标的多任务学习。每年的顶会都会出现不少关于loss优化的文章,还有大量的新loss定义方式,眼花缭乱。因此,一个深度学习任务中,一个loss往往有由多个loss复合而成:loss = a*loss1+b*loss2+c*loss3...

如何调整这些loss权重?超参数a、b、c……之间经过怎样的组合才是最优?

手动调参

都说DL工程师是炼丹工程师,有经验的工程师会有很强的参数嗅觉,从直觉上判断哪些 可行、哪些不可行。一个有效的方法是从单独loss的特性和对整体的影响效果出发,酌情调参。如:loss = a*loss1+b*loss2+c*loss3,要求a+b+c = 1。

手动调参的缺点非常明显,困难、时间昂贵、可解释性差、经验要求高等等。

自动调参

这是本文的重点。

多任务学习的目的是通过从共享表示中学习多个目标来提高学习效率和预测精度。从计算机视觉到自然语言处理再到语音识别,多任务学习在机器学习的许多应用中都很普遍。

从每个任务的同方差不确定性出发

许多深度学习应用都受益于具有多重回归和分类目标的多任务学习。在本文中,我们观察到这种系统的性能强烈依赖于每个任务损失之间的相对权重。手工调优这些权重是一个困难且昂贵的过程,使得多任务学习在实践中难以实现。我们提出了一种有原则的多任务深度学习方法,通过考虑每个任务的同方差不确定性来权衡多个损失函数。这使得我们能够同时在分类和回归设置中学习不同单位或尺度的不同数量。我们演示了我们的模型学习每像素深度回归,语义和实例分割从单目输入图像。也许令人惊讶的是,我们证明了我们的模型可以学习多任务权重,并且比针对每个任务单独训练的单独模型表现得更好。

在这项工作中,我们提出了一个原则的方法,结合多个损失函数,同时学习多个目标利用同方差不确定性。我们将同方差不确定性解释为任务相关的加权,并演示如何推导有原则的多任务损失函数,该函数可以学习平衡各种回归和分类损失。我们的方法可以学会平衡这些权重,从而获得更好的性能,而不是单独学习每个任务。

具体来说,我们演示了我们的方法在学习场景几何和语义的三个任务。首先,我们学习在像素级对对象进行分类,也称为语义分割。其次,我们的模型执行实例分割,这是一项较困难的任务,即为图像中每个单独的目标分割单独的遮罩(例如,为道路上的每一辆汽车分割单独的、精确的遮罩)。这是比语义分割更加困难的任务,不仅需要估计每个像素的类别,还需要估计像素所属的对象。它也比物体检测更复杂,物体检测通常只预测物体边界框。最后,我们的模型预测像素的度量深度。深度识别已经通过使用有监督和无监督深度学习的密集预测网络演示。然而,很难用一种一般化的方式来估计深度。我们表明,通过使用语义标签和多任务深度学习,我们可以改进几何和深度的估计

论文赏析

2018年的CVPR,地址在此:

https://openaccess.thecvf.com/content_cvpr_2018/papers/Kendall_Multi-Task_Learning_Using_CVPR_2018_paper.pdf

代码分析

先创建文件:AutomaticWeightedLoss.py

# -*- coding: utf-8 -*-

import torch
import torch.nn as nn

class AutomaticWeightedLoss(nn.Module):
    """automatically weighted multi-task loss
    Params:
        num: int,the number of loss
        x: multi-task loss
    Examples:
        loss1=1
        loss2=2
        awl = AutomaticWeightedLoss(2)
        loss_sum = awl(loss1, loss2)
    """
    def __init__(self, num=2):
        super(AutomaticWeightedLoss, self).__init__()
        params = torch.ones(num, requires_grad=True)
        self.params = torch.nn.Parameter(params)

    def forward(self, *x):
        loss_sum = 0
        for i, loss in enumerate(x):
            loss_sum += 0.5 / (self.params[i] ** 2) * loss + torch.log(1 + self.params[i] ** 2)
        return loss_sum

if __name__ == '__main__':
    awl = AutomaticWeightedLoss(2)
    print(awl.parameters())

在同一文件夹下:

git clone git@github.com:Mikoto10032/AutomaticWeightedLoss.git

在自己代码中:

from AutomaticWeightedLoss import AutomaticWeightedLoss

awl = AutomaticWeightedLoss(2)	# we have 2 losses
loss1 = (your loss 1)
loss2 = (your loss 2)
loss_sum = awl(loss1, loss2)

计算优化:

from torch import optim

model = Model()
optimizer = optim.Adam([
                {'params': model.parameters()},
                {'params': awl.parameters(), 'weight_decay': 0}	
            ])

demo:

from torch import optim
from AutomaticWeightedLoss import AutomaticWeightedLoss

model = Model()

awl = AutomaticWeightedLoss(2)	# we have 2 losses
loss_1 = ...
loss_2 = ...

# learnable parameters
optimizer = optim.Adam([
                {'params': model.parameters()},
                {'params': awl.parameters(), 'weight_decay': 0}
            ])

for i in range(epoch):
    for data, label1, label2 in data_loader:
        # forward
        pred1, pred2 = Model(data)	
        # calculate losses
        loss1 = loss_1(pred1, label1)
        loss2 = loss_2(pred2, label2)
        # weigh losses
        loss_sum = awl(loss1, loss2)
        # backward
        optimizer.zero_grad()
        loss_sum.backward()
        optimizer.step()

视频讲解

https://www.youtube.com/watch?v=zVYY9HaEJnc&list=PL_bDvITUYucCIT8iNGW8zCXeY5_u6hg-y&index=2&t=0s

说明:写在最后

虽然mult-task learning很火,使用上也非常简单,但是是否能起作用还是玄学。总之,多尝试,没有坏处,万一有用呢?

以上是关于深度学习中如何平衡多个loss?多任务学习自动调整loss weight解决方案的主要内容,如果未能解决你的问题,请参考以下文章

多任务学习——ICML 2018GradNorm

多任务学习——ICML 2018GradNorm

slowfast 损失函数改进深度学习网络通用改进方案:slowfast的损失函数(使用focal loss解决不平衡数据)改进

slowfast 损失函数改进深度学习网络通用改进方案:slowfast的损失函数(使用focal loss解决不平衡数据)改进

吴恩达-医学图像人工智能专项课程-第一课第一周13-15节-迁移学习+数据增强

深度学习:batch_size和学习率 及如何调整