pytorch torch.no_grad()函数(禁用梯度计算)(当确保下文不用backward()函数计算梯度时可以用,用于禁用梯度计算功能,以加快计算速度)

Posted Dontla

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了pytorch torch.no_grad()函数(禁用梯度计算)(当确保下文不用backward()函数计算梯度时可以用,用于禁用梯度计算功能,以加快计算速度)相关的知识,希望对你有一定的参考价值。

Context-manager that disabled gradient calculation.
Disabling gradient calculation is useful for inference, when you are sure that you will not call Tensor.backward(). It will reduce memory consumption for computations that would otherwise have requires_grad=True.
In this mode, the result of every computation will have requires_grad=False, even when the inputs have requires_grad=True.
This context manager is thread local; it will not affect computation in other threads.
Also functions as a decorator. (Make sure to instantiate with parenthesis.

Note
No-grad is one of several mechanisms that can enable or disable gradients locally see locally-disable-grad-doc for more information on how they compare.

文档复制自:
no_grad
  < Python 3.8 >

解释

禁用梯度计算的上下文管理器。
当您确定不会调用Tensor.backward()时,禁用梯度计算对于推理非常有用。它将减少原本需要_grad=True的计算的内存消耗。
在此模式下,每次计算的结果都将具有requires_grad=False,即使输入具有requires_grad=True。
这个上下文管理器是线程本地的;它不会影响其他线程中的计算。
也可以作为装饰(确保用括号实例化)

示例1

Example:
>>> x = torch.tensor([1], requires_grad=True)
>>> with torch.no_grad():
...   y = x * 2
>>> y.requires_grad
False
>>> @torch.no_grad()
... def doubler(x):
...     return x * 2
>>> z = doubler(x)
>>> z.requires_grad
False

@表示python中的装饰,没学过不知道用法@//??啥时候研究下

示例2

# 定义优化算法
def sgd(params, lr, batch_size):  # @save
    """⼩批量随机梯度下降。"""
    with torch.no_grad():   # 禁用梯度计算以加快计算速度
        for param in params:
            param -= lr * param.grad / batch_size   # 为什么要除以batch_size?
            # 因为我们计算的损失是⼀个批量样本的总和,
            # 所以我们⽤批量⼤小(batch_size)来归⼀化步⻓,这样步⻓⼤小就不会取决于我们对批量⼤小的选择。
            param.grad.zero_()  # 梯度清零

以上是关于pytorch torch.no_grad()函数(禁用梯度计算)(当确保下文不用backward()函数计算梯度时可以用,用于禁用梯度计算功能,以加快计算速度)的主要内容,如果未能解决你的问题,请参考以下文章

【Pytorch】model.eval() vs torch.no_grad()

pytorch中的train.eval() 与 with torch.no_grad()的使用

pytorch中model.eval()和torch.no_grad()的区别

pytorch中model.eval()和torch.no_grad()的区别

与torch.no_grad:AttributeError:__enter__

Torch.no_grad()影响MSE损失