torch梯度计算相关
Posted LinXiaoshu
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了torch梯度计算相关相关的知识,希望对你有一定的参考价值。
torch梯度计算图
计算图中,默认只有叶子结点的梯度能够保留,如果要访问非叶子结点p
的梯度数据,需要执行p.retain_grad()
.
torch计算图中requires_grad
与detach
的区别
requires_grad
是torch.Tensor
中的属性,表示该张量是否需要计算梯度.而detach()
则是方法,将此张量从当前计算图中脱离.这两者的区别在于:调用detach()
后,默认将会把requires_grad
设置为False
.但脱离计算图只是阻断了此张量的梯度向后传播,脱离计算图仍然可以被计算梯度.
比如,在torch.utils.checkpoint
中,有
def detach_variable(inputs: Tuple[Any, ...]) -> Tuple[torch.Tensor, ...]:
if isinstance(inputs, tuple):
out = []
for inp in inputs:
if not isinstance(inp, torch.Tensor):
out.append(inp)
continue
x = inp.detach()
x.requires_grad = inp.requires_grad
out.append(x)
return tuple(out)
else:
raise RuntimeError(
"Only tuple of tensors is supported. Got Unsupported input type: ", type(inputs).__name__)
这里将inp
从计算图中脱离,却仍然指定其需要梯度,就是要求此梯度被计算,但不传播.
以上是关于torch梯度计算相关的主要内容,如果未能解决你的问题,请参考以下文章
pytorch 笔记:torch.distributions 概率分布相关(更新中)
pytorch 笔记:torch.distributions 概率分布相关(更新中)