如何在 autograd 反向传播中禁用某些模块的梯度更新?

Posted

技术标签:

【中文标题】如何在 autograd 反向传播中禁用某些模块的梯度更新?【英文标题】:How can I disable gradient updates for some modules in autograd backpropagation? 【发布时间】:2020-01-16 14:42:03 【问题描述】:

我正在构建一个用于强化学习的多模型神经网络,其中包括一个动作网络、一个世界模型网络和一个批评者。这个想法是训练世界模型以根据来自动作网络的输入和先前状态来模拟您试图掌握的任何模拟,训练评论家以根据世界模型输出最大化贝尔曼方程(随时间的总强化),然后通过世界模型反向传播critic值,为训练动作提供梯度目标。所以 - 从某个状态,动作网络输出一个动作,该动作被馈送到模型中以生成下一个状态,并且该状态馈送到评论家网络以针对某个目标状态进行评估。

为了使所有这些工作,我必须使用 3 个单独的损失函数,每个网络一个,它们都在一个或多个网络的梯度中添加一些东西,但它们可能会发生冲突。例如 - 为了训练世界模型,我使用来自环境模拟的目标,而对于评论家,我使用当前状态奖励 + 折扣 * 下一个状态预测值的目标。然而,为了训练一个演员,我只是使用负面评论家值作为损失,并在所有三个模型中一直反向传播以校准最佳动作。

我可以通过逐步将梯度归零来完成这项工作,而无需任何批处理,但这是低效的,并且不允许我为任何类型的“时间序列批处理”优化器更新步骤累积梯度。每个模型都有自己的可训练参数,但执行图流经所有三个网络。因此,在依次触发网络后的校准循环中: ...

        if self.actor.calibrating:
            self.actor.optimizer.zero_grad()
            #Pick loss For maximizing the value of all actions
            loss = -self.critic.value
            #Backpropagate through all three networks to train actor output
            #How do I stop the critic and model networks from incrementing their gradient values?
            loss.backward(retain_graph=True)
            self.actor.optimizer.step()
        if self.model.calibrating:
            self.model.optimizer.zero_grad()
            #Reduce loss for ambiguous actions
            loss = self.model.get_loss() * self.actor.get_confidence()**2
            #How can I block this from backpropagating through action network?
            loss.backward(retain_graph=True)
            self.model.optimizer.step()
        if self.critic.calibrating:
            self.critic.optimizer.zero_grad()
            #Reduce loss for ambiguous actions
            loss = self.critic.get_loss(self.goal) * self.actor.get_confidence()**2
            #How do I stop this from backpropagating through the model and action networks?
            loss.backward(retain_graph=True)
            self.critic.optimizer.step()

...

最后 - 我的问题分为两部分:

    如何在给定层暂时停止 loss.backward() 而不会永远分离它? 如何阻止 loss.backward() 更新一些渐变,而我只是流经一个模型以获取另一个模型的渐变?

【问题讨论】:

我的一位同事给了我一些似乎有效的见解。如果一切顺利,我会在明天发布。 【参考方案1】:

感谢一位同事的建议,尝试使用 requires_grad 设置。 (我曾假设这会破坏执行图,但事实并非如此)

所以 - 回答我自己的两个问题:

    如果您以正确的顺序校准链接模型,您可以一次分离一个,这样 loss.backward() 就不会超出不需要的模型。我在想这会破坏图表,但是......这是 Pytorch,而不是 Tensorflow 1.x,无论如何,图表都会在每次前向传递时重新生成。昨天错过了这个,我真傻。 如果您将模型(或层或单个权重)的 requires_grad 设置为 False,则 loss.backward() 仍将遍历整个连接图,但它会保留这些单个梯度,同时仍会较早设置任何梯度在图中。正是我想要的。

此代码可最大限度地减少不必要的图形遍历和梯度更新的执行。我仍然需要重构它以随着时间的推移进行交错更新,以便它可以在步进优化器之前累积几个周期的梯度,但这绝对可以按预期工作。

#Step through all models in a chain to create gradient paths from critic back through the world model, to the actor.
    def step(self):
        #Get the current state from the simulation
        state = self.world.state
        #Fire the actor to select a softmax action.
        self.actor(state)
        #run the world simulation on that action.
        self.world.step(self.actor.action)
        #Combine the action and starting state as input to the world model.
        if self.actor.calibrating:
            action_state = torch.cat([self.actor.value, state], dim=0)
        else:
            #Push softmax action closer to 1.0
            action_state = torch.cat([self.actor.hard_value, state], dim=0)
        #Run the model and then the critic on the action_state
        self.critic(self.model(action_state))
        if self.actor.calibrating:
            self.actor.optimizer.zero_grad()
            self.model.requires_grad = False
            self.critic.requires_grad = False
            #Pick loss For maximizing the value of the action choice
            loss = -self.critic.value * self.actor.get_confidence()
            loss.backward(retain_graph=True)
            self.actor.optimizer.step()
        if self.model.calibrating:
            #Don't need to backpropagate through actor again
            self.actor.value.detach_()
            self.model.optimizer.zero_grad()
            self.model.requires_grad = True
            #Reduce loss for ambiguous actions
            loss = self.model.get_loss() * self.actor.get_confidence()**2
            loss.backward(retain_graph=True)
            self.model.optimizer.step()
        if self.critic.calibrating:
            #Don't need to backpropagate through the model or actor again
            self.model.value.detach_()
            self.critic.optimizer.zero_grad()
            self.critic.requires_grad = True
            #Reduce loss for ambiguous actions
            loss = self.critic.get_loss(self.goal) * self.actor.get_confidence()**2
            loss.backward(retain_graph=True)
            self.critic.optimizer.step()

【讨论】:

以上是关于如何在 autograd 反向传播中禁用某些模块的梯度更新?的主要内容,如果未能解决你的问题,请参考以下文章

Autograd:自动微分

动手深度学习15-深度学习-正向传播反向传播和计算图

pytorch学习笔记第三篇———自动梯度(torch.autograd)

Torch反向传播时出错或者梯度为NaN的问题排查

深扒torch.autograd原理

pytorch中的Variable()