PyTorch 中“detach()”和“with torch.nograd()”的区别?

Posted

技术标签:

【中文标题】PyTorch 中“detach()”和“with torch.nograd()”的区别?【英文标题】:Difference between "detach()" and "with torch.nograd()" in PyTorch? 【发布时间】:2019-11-10 23:07:17 【问题描述】:

我知道从梯度计算中排除计算元素的两种方法backward

方法一:使用with torch.no_grad()

with torch.no_grad():
    y = reward + gamma * torch.max(net.forward(x))
loss = criterion(net.forward(torch.from_numpy(o)), y)
loss.backward();

方法二:使用.detach()

y = reward + gamma * torch.max(net.forward(x))
loss = criterion(net.forward(torch.from_numpy(o)), y.detach())
loss.backward();

这两者有区别吗?两者都有优点/缺点吗?

【问题讨论】:

【参考方案1】:

tensor.detach() 创建一个与不需要 grad 的张量共享存储的张量。它将输出与计算图分离。所以不会沿着这个变量反向传播梯度。

包装器with torch.no_grad() 临时将所有requires_grad 标志设置为false。 torch.no_grad 表示任何操作都不应该构建图。

不同之处在于它只引用一个给定的变量,它被调用。另一个影响with 语句中发生的所有操作。此外,torch.no_grad 将使用更少的内存,因为它从一开始就知道不需要渐变,因此不需要保留中间结果。

通过here 中的示例详细了解它们之间的区别。

【讨论】:

【参考方案2】:

detach()

一个没有detach()的例子:

from torchviz import make_dot
x=torch.ones(2, requires_grad=True)
y=2*x
z=3+x
r=(y+z).sum()    
make_dot(r)

绿色 r 的最终结果是 AD 计算图的根,蓝色是叶张量。

detach() 的另一个例子:

from torchviz import make_dot
x=torch.ones(2, requires_grad=True)
y=2*x
z=3+x.detach()
r=(y+z).sum()    
make_dot(r)

这与:

from torchviz import make_dot
x=torch.ones(2, requires_grad=True)
y=2*x
z=3+x.data
r=(y+z).sum()    
make_dot(r)

但是,x.data 是旧方式(符号),x.detach() 是新方式。

x.detach()有什么区别

print(x)
print(x.detach())

输出:

tensor([1., 1.], requires_grad=True)
tensor([1., 1.])

所以 x.detach() 是一种删除requires_grad 的方法,你得到的是一个新的分离的张量(从AD计算图分离)。

torch.no_grad

torch.no_grad实际上是一个类。

x=torch.ones(2, requires_grad=True)
with torch.no_grad():
    y = x * 2
print(y.requires_grad)

输出:

False

来自help(torch.no_grad)

当您确定时,禁用梯度计算对推理很有用 |你不会打电话给:meth:Tensor.backward()。会减少记忆 |否则将具有requires_grad=True 的计算消耗。 | |在这种模式下,每次计算的结果都会有 | requires_grad=False,即使输入有 requires_grad=True

【讨论】:

感谢您的回答...提供了计算图中 .data 和分离函数的快速直观概述 @prosti AD的完整形式和含义是什么? duckduckgo.com/… ***上的第二个链接【参考方案3】:

一个简单而深刻的解释是,with torch.no_grad() 的使用就像一个循环,其中写入的所有内容都会在其中将requires_grad 参数设置为False,尽管是暂时的。因此,如果您需要停止从某些变量或函数的梯度进行反向传播,则无需指定任何其他内容。

然而,torch.detach() 顾名思义,只是简单地将变量从梯度计算图中分离出来。但是,当必须为有限数量的变量或函数提供此规范时,例如使用此规范。通常在神经网络训练结束后显示损失和准确性输出,因为在那一刻,它只消耗资源,因为它的梯度在结果显示期间无关紧要。

【讨论】:

简单!这是一个很好的回应。

以上是关于PyTorch 中“detach()”和“with torch.nograd()”的区别?的主要内容,如果未能解决你的问题,请参考以下文章

pytorch中copy_()detach()data()和clone()操作区别小结

为啥我们在 Pytorch 张量上调用 .numpy() 之前调用 .detach()?

R语言笔记 attach()detach()和with()

pytorch torch.detach函数(返回一个新的`Variable`,从当前图中分离下来的)

关于pytorch中inplace运算需要注意的问题

pytorch中gather函数的理解。