受限显存下增加batchsize策略:gradient checkpointing
Posted
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了受限显存下增加batchsize策略:gradient checkpointing相关的知识,希望对你有一定的参考价值。
参考技术Ahttps://blog.csdn.net/lavinia_chen007/article/details/113609838
我是在 Swin-Transformer的开源 里找到的:
When GPU memory is not enough, you can try the following suggestions:
Use gradient accumulation by adding --accumulation-steps <steps>, set appropriate <steps> according to your need.
Use gradient checkpointing by adding --use-checkpoint, e.g., it saves about 60% memory when training Swin-B. Please refer to this page for more details.
We recommend using multi-node with more GPUs for training very large models, a tutorial can be found in this page .
看看PyTorch官网怎么说:
https://pytorch.org/docs/stable/checkpoint.html
注意:
Checkpointing是通过在反向传播过程中为每个Checkpointed段重新运行前向传播分段来实现的。这可能会导致像RNG状态这样的持久状态比没有Checkpointing的状态更高级。默认情况下,Checkpointing包括改变RNG状态的逻辑,这样,与非Checkpointed过程相比,使用RNG的Checkpointing过程(例如通过dropout)具有确定性输出。根据Checkpointing操作的运行时间,隐藏和恢复RNG状态的逻辑可能会导致适度的性能损失(moderate performance hit)。如果不需要与非Checkpointing过程相比较的确定性输出,则向Checkpointing或Checkpointing顺序提供preserve_rng_state=False,以在每个Checkpointing期间省略存储和恢复rng状态。
存储逻辑将当前设备和所有cuda张量参数的设备的RNG状态保存并恢复到run_fn。但是,逻辑无法预测用户是否会将张量移动到run_fn自身内的新设备。因此,如果在run_fn中将张量移动到新设备(“new”表示不属于[当前设备+张量参数的设备]),则与非Checkpointing过程相比,决不能保证确定性输出。
Checkpointing模型或模型的一部分
Checkpointing的工作原理是用计算换取内存。Checkpointing部分不会存储整个计算图的所有中间激活以进行反向计算,不会保存中间激活,而是在反向过程中重新计算它们。它可以应用于模型的任何部分。
具体来说,在向前传递中,函数将以torch.no_grad()方式运行,即不存储中间激活。相反,向前传递保存输入元组和函数参数。在向后传递中,检索保存的输入和函数,并再次在函数上计算向前传递,现在跟踪中间激活,然后使用这些激活值计算梯度。
函数的输出可以包含非张量值,仅对张量值执行梯度记录。请注意,如果输出包含由张量组成的嵌套结构(例如:自定义对象、列表、dict等),则嵌套在自定义结构中的这些张量将不会被视为autograd的一部分。
警告:
Checkpointing当前仅支持torch.autograd.backward(),并且仅当其输入参数未传递时才支持。不支持torch.autograd.grad()。
警告:
如果向后期间的函数调用与向前期间的函数调用不同,例如,由于某些全局变量,Checkpointing版本将不等效,不幸的是,它无法被检测到。
警告:
如果Checkpointing段包含由detach()或torch.no_grad()从计算图中分离的张量,则向后传递将引发错误。这是因为Checkpointing使得所有的输出都需要梯度,当张量被定义为在模型中没有梯度时,这会导致问题。要避免这种情况,请分离Checkpointing函数外部的张量。
警告:
如果模型输入需要梯度,则至少有一个输入需要具有requires_grad=True,否则模型的Checkpointing部分将不具有梯度。至少有一个输出需要同时具有requires_grad=True。
输入参数列表:
function –描述在模型或部分模型的正向传递中运行的内容。它还应该知道如何处理作为元组传递的输入。例如,在LSTM中,如果用户通过(激活,隐藏),函数应正确使用第一个输入作为激活,第二个输入作为隐藏
preserve_rng_state–(bool,可选,默认值=True)–在每个Checkpointing期间省略存储和恢复rng状态。
args–包含函数输入的元组
返回:
*args上运行函数的输出
顺序模型按顺序(顺序)执行模块/功能列表。因此,我们可以将这样一个模型划分为不同的部分,并检查每个部分。除最后一段以外的所有段都将以torch.no_grad()方式运行,即不存储中间激活。将保存每个Checkpointing段的输入,以便在反向过程中重新运行该段。
请参阅checkpoint(),了解Checkpointing的工作原理。
警告
Checkpointing当前仅支持torch.autograd.backward(),并且仅当其输入参数未传递时才支持。不支持torch.autograd.grad()。
参数
functions –torch.nn.Sequential或要按顺序运行的模块或功能(包括模型)列表。
segments–要在模型中创建的块数
input –输入到函数的张量
preserve_rng_state(bool,可选,默认值=True)–在每个Checkpointing期间省略存储和恢复rng状态。
返回:
按*输入顺序输出运行函数
实例
上面那个博主的示例:
注意第94行,必须确保checkpoint的输入输出都声明为require_grad=True的Variable,否则运行时会报如下的错
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
再看看这位老哥的, https://blog.csdn.net/weixin_43002433/article/details/105322846
再再看看这位的: https://blog.csdn.net/ONE_SIX_MIX/article/details/93937091
以上是关于受限显存下增加batchsize策略:gradient checkpointing的主要内容,如果未能解决你的问题,请参考以下文章