使用`checkpoint`进行显存优化的学习笔记

Posted songyuc

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了使用`checkpoint`进行显存优化的学习笔记相关的知识,希望对你有一定的参考价值。

1 介绍

Checkpoint的主要原理是:在前向阶段传递到checkpoint中的forward函数会以 torch.no_grad 模式运行,并且仅仅保存输入参数和 forward 函数,在反向阶段重新计算其 forward 输出值。
(引用于《拿什么拯救我的 4G 显卡 | OpenMMLab》

2 写作思路

  • 只在nn.Module的上层模块使用checkpoint,而不是在大模型的forward函数中写作;

3 示例代码

使用checkpoint的示例代码:

我们可以学习使用checkpoint进行显存优化,示例代码如下:

def forward(self, x):
    def _inner_forward(x):
        identity = x
        out = self.conv1(x)
        out = self.norm1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.norm2(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        return out
        
    # x.requires_grad 这个判断很有必要
    if self.with_cp and x.requires_grad:
        out = cp.checkpoint(_inner_forward, x)
    else:
        out = _inner_forward(x)
    out = self.relu(out)
    return out

以上是关于使用`checkpoint`进行显存优化的学习笔记的主要内容,如果未能解决你的问题,请参考以下文章

后向重计算在OneFlow中的实现:以时间换空间,大幅降低显存占用

机器学习笔记:优化器Lion(EvoLved Sign Momentum)

受限显存下增加batchsize策略:gradient checkpointing

深度学习分布式策略优化显存优化通信优化编译优化综述

深度学习分布式策略优化显存优化通信优化编译优化综述

深度学习分布式策略优化显存优化通信优化编译优化综述