如何在 PyTorch 中清除 Cuda 内存
Posted
技术标签:
【中文标题】如何在 PyTorch 中清除 Cuda 内存【英文标题】:How to clear CUDA memory in PyTorch 【发布时间】:2019-08-14 18:53:41 【问题描述】:我正在尝试获取我已经训练过的神经网络的输出。输入是大小为 300x300 的图像。我使用的批量大小为 1,但在成功获得 25 张图像的输出后,我仍然收到 CUDA error: out of memory
错误。
我在网上搜索了一些解决方案,发现了torch.cuda.empty_cache()
。但这似乎仍然不能解决问题。
这是我正在使用的代码。
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
train_x = torch.tensor(train_x, dtype=torch.float32).view(-1, 1, 300, 300)
train_x = train_x.to(device)
dataloader = torch.utils.data.DataLoader(train_x, batch_size=1, shuffle=False)
right = []
for i, left in enumerate(dataloader):
print(i)
temp = model(left).view(-1, 1, 300, 300)
right.append(temp.to('cpu'))
del temp
torch.cuda.empty_cache()
这个for loop
每次运行25次才给出内存错误。
每次,我都会在网络中发送一个新图像进行计算。因此,在循环中的每次迭代之后,我真的不需要将先前的计算结果存储在 GPU 中。有什么方法可以实现吗?
任何帮助将不胜感激。谢谢。
【问题讨论】:
【参考方案1】:我知道我哪里出错了。我将发布解决方案作为其他可能遇到相同问题的人的答案。
基本上,PyTorch 所做的是每当我通过网络传递数据并将计算结果存储在 GPU 内存上时,它都会创建一个计算图,以防我想在反向传播期间计算梯度。但由于我只想执行前向传播,我只需为我的模型指定torch.no_grad()
。
因此,我的代码中的 for 循环可以重写为:
for i, left in enumerate(dataloader):
print(i)
with torch.no_grad():
temp = model(left).view(-1, 1, 300, 300)
right.append(temp.to('cpu'))
del temp
torch.cuda.empty_cache()
为我的模型指定 no_grad()
告诉 PyTorch 我不想存储任何以前的计算,从而释放我的 GPU 空间。
【讨论】:
这很有趣。改变模型的模式(从训练到评估)有帮助吗?我想知道是否有一个内部机制可以自动告诉 pytorch 模式已更改为 eval 所以不需要保存计算?这意味着如果 net.eval() 没有明确告诉 pytorch 在前向传递期间不保存计算,我可以使用“with torch.no_grad()”进行验证和推理? 为了进行推理(只是前向传递),您只需要指定 net.eval() 它将禁用您的 dropout 和 batchnorm 层,将模型置于评估模式。但是,强烈建议也将它与 torch.no_grad() 一起使用,因为它会禁用 autograd 引擎(在推理过程中您可能不希望使用它),这将节省您的时间和内存。只做 net.eval() 仍然会计算梯度,使其变慢并消耗你的内存。 如果我通过 .numpy().cpu() 将数据张量(比如说预测和 groundtruth)发送到 cpu(),我还需要提及“with torch.no_grad()”吗? 如果你的变量有requires_grad=True
,那么你不能直接调用.numpy()。您首先必须执行 .detach() 来告诉 pytorch 您不想计算该变量的梯度。接下来,如果您的变量在 GPU 上,您首先需要将其发送到 CPU 以便使用 .cpu() 转换为 numpy。因此,它将类似于var.detach().cpu().numpy()
。
但是使用 torch.no_grad(),你不需要提及 .detach(),因为无论如何都不会计算梯度。以上是关于如何在 PyTorch 中清除 Cuda 内存的主要内容,如果未能解决你的问题,请参考以下文章
5.51 GiB 已分配; 417.00 MiB 免费; PyTorch CUDA 总共保留 5.53 GiB 内存不足