显存不足?PyTorch 显存使用分析与优化

Posted 专知

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了显存不足?PyTorch 显存使用分析与优化相关的知识,希望对你有一定的参考价值。

面对动辄几百万几千万参数量的模型, GPU那连常规 U盘都比不过的显存, 真的是杯水车薪。相信大家在日常模型训练过程中,或多或少的总会遇见:

torch.FatalError: cuda runtime error (2) : out of memory 


那么, 可怜的显存到底在哪里消耗掉了?


显存去哪了?


首先我们看一下 PyTorch 官方的数据格式(由于文本只讨论显存,故将 CPU tensor 去掉了):

 

其中,大家日常训练用到的一般是torch.cuda.FloatTensor和torch.cuda.IntTensor, 根据上表,它们分别是32位浮点数和32位整数,都占4个 Byte.

也就是说,常见的一张224x224大小的图片,占用 个 Byte,约为588K。


根据常用的张量组织形式(bathc_size, channel, height, width), 那么一个大小时(128,3, 224, 224) 的张量约占用73.5M。

 

让我们来看一下常见的 VGG-16的内存占用:

显存不足?PyTorch 显存使用分析与优化

上述图片来自 CS231n 的课件 

(http://cs231n.stanford.edu/slides/2017/cs231n_2017_lecture9.pdf)


是 VGG-16 在输入单张图片,且在不计算 biases的情况下计算的内存和参数量占用。


红色字体代表:从输入图像开始,每一个的卷积池化,全连接所产升的内存消耗。

蓝色字体代表:每一步计算,所需要记住的参数的个数


注意,这只是前向 forward 时候的内存消耗,计算反向传播的时候,上述消耗要至少再乘以2,因为链式法则要记很多中间结果。

 

最终的结果有点沉重,单张图片96M,反向传播时单张图片192M。这时候,我们稍微升高一下 batch_size, 比如batch_size = 128, 结果可想而知

 

处理模型设计的固有内存占用,还有很多模块也占用非常多的显存。


比如:优化器


比如常用的随机梯度下降SGD,它的计算公式是:


显存不足?PyTorch 显存使用分析与优化

显然,程序是要记忆显存不足?PyTorch 显存使用分析与优化的,这一点在上面讲过。


但是,如果要使用动量(Momentum-SGD) , 程序还要额外记忆动量, 占用显存 x3:

显存不足?PyTorch 显存使用分析与优化

显存不足?PyTorch 显存使用分析与优化


而更常用的,如果我们使用 Adam 优化器,占用显存 x4。

 

总的来说,显存占用分四个部分:

  • 模型参数

  • 模型计算中间结果

  • 反向传播中间结果

  • 优化器额外参数


计算模型显存占用


比如我们有一个如下的 PyTorch 模型:


Sequential(
(conv_1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1),
padding=(1, 1))
(relu_1): ReLU(inplace)
(conv_2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1),
padding=(1, 1))
(relu_2): ReLU(inplace)
(pool_2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1,
ceil_mode=False)
(conv_3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1),
padding=(1, 1))
)


我们可以使用model.parameters() 将模型的参数拿出来,然后统计一下个数:

import numpy asnp
# 模型内参数个数
para =sum([np.prod(list(p.size())) for p in model.parameters()])
# float32 占 4个 Byte
type_size =4
model_size =para * type_size /1024/1024
print(model_size)
# 约为0.43M


显然,这仅仅是第一部分,模型参数

 

模型计算的中间结果可以通过传入一个简单输入,然后提取中间结果来统计:

input_ =input.clone()  
# 取消计算梯度,我们现在只计算模型中间结果的大小
input_.requires_grad_(requires_grad=False)

mods =list(model.modules())
out_sizes =[]

fori inrange(1, len(mods)):
   m =mods[i]
   ifisinstance(m,nn.ReLU):  
       ifm.inplace:
          continue
   
out =m(input_)
  out_sizes.append(np.array(out.size()))
  input_ = out

total_nums =0
fori inrange(len(out_sizes)):
   s =out_sizes[i]
   nums =np.prod(np.array(s))
  total_nums += nums
   

cost_size =total_nums * type_size /1000/1000
print(cost_size)
# 约为320.5M


正向计算结果约为320.5M,反向x2,约为641M

 

如何优化

出了模型层面的优化,其他的显存优化主要有以下几点:

  • 减少batch,减少每次的输入图像数量

  • 多使用下采样,池化层等

  • 一些神经网络层可以进行小优化,利用relu层中设置inplace


PyTorch 0.4推出了一个新功能,能够在一定程度上解决显存不足的问题,测量主要是拿时间换空间。它的实现方式是将一个计算过程分成很多份,我们就可以先计算一部分,保存后一部分需要的中间结果,然后再计算后一部分。


也就是说,新的checkpoint允许我们只存储反向传播所需要的部分内容。如果当中缺少一个输出(为了节省内存而导致的),checkpoint将会从最近的检查点重新计算中间输出,以便减少显存使用:


比如,我们有一个1000层的全连接:

input=torch.rand(1, 10)
# 1000 个 10 x 10 的全连接层
layers = [nn.Linear(10, 10) for _ inrange(1000)]
model = nn.Sequential(*layers)
output = model(input)


这个模型如果正常运行将占用海量显存,因为计算中会产生很多的中间变量。现在我们可以使用checkpoint节省资源占用。

input = torch.rand(1, 10, requires_grad=True)
layers = [nn.Linear(10, 10) for _ in range(1000)]

# 计算第一部分
def run_a(*args):
   x = args[0]
   for layer in layers[:500]:
       x = layer(x)
   return x

# 计算第二部分
def run_b(*args):
   x = args[0]
   for layer in layers[500:-1]:
       x = layer(x)
   return x

from torch.utils.checkpoint import checkpoint
result_a = checkpoint(run_a, input)
result_b = checkpoint(run_b, result_a)
total_result = layers[-1](result_b)
output = total_result.sum()
output.backward()


-END-

专 · 知


加入专知人工智能服务群: 欢迎微信扫一扫加入专知人工智能知识星球群,获取专业知识教程视频资料和与专家交流咨询!

显存不足?PyTorch 显存使用分析与优化


请PC登录www.zhuanzhi.ai或者点击阅读原文,注册登录专知,获取更多AI知识资料

请加专知小助手微信(扫一扫如下二维码添加),加入专知主题群(请备注主题类型:AI、NLP、CV、 KG等)交流~

点击“阅读原文”,使用专知

以上是关于显存不足?PyTorch 显存使用分析与优化的主要内容,如果未能解决你的问题,请参考以下文章

[Pytorch]深度模型的显存计算以及优化

把显存用在刀刃上!17 种 pytorch 节约显存技巧

PyTorch显存机制分析

深入理解卷积与模型大小问题,解决显存不足

Facebook推出8比特优化器,可以节省75%的显存

WebGL与Canvas的显存与内存使用分析