使用`checkpoint`进行显存优化的学习笔记
Posted songyuc
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了使用`checkpoint`进行显存优化的学习笔记相关的知识,希望对你有一定的参考价值。
1 介绍
Checkpoint的主要原理是:在前向阶段传递到checkpoint
中的forward
函数会以 torch.no_grad 模式运行,并且仅仅保存输入参数和 forward 函数,在反向阶段重新计算其 forward 输出值。
(引用于《拿什么拯救我的 4G 显卡 | OpenMMLab》)
2 写作思路
- 只在
nn.Module
的上层模块使用checkpoint,而不是在大模型的forward函数中写作;
3 示例代码
使用checkpoint
的示例代码:
我们可以学习使用checkpoint
进行显存优化,示例代码如下:
def forward(self, x):
def _inner_forward(x):
identity = x
out = self.conv1(x)
out = self.norm1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.norm2(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
return out
# x.requires_grad 这个判断很有必要
if self.with_cp and x.requires_grad:
out = cp.checkpoint(_inner_forward, x)
else:
out = _inner_forward(x)
out = self.relu(out)
return out
以上是关于使用`checkpoint`进行显存优化的学习笔记的主要内容,如果未能解决你的问题,请参考以下文章
后向重计算在OneFlow中的实现:以时间换空间,大幅降低显存占用
机器学习笔记:优化器Lion(EvoLved Sign Momentum)