PyTorch 中的后向函数
Posted
技术标签:
【中文标题】PyTorch 中的后向函数【英文标题】:Backward function in PyTorch 【发布时间】:2019-12-06 11:52:51 【问题描述】:我对 pytorch 的后向功能有一些疑问,我认为我没有得到正确的输出:
import numpy as np
import torch
from torch.autograd import Variable
a = Variable(torch.FloatTensor([[1,2,3],[4,5,6]]), requires_grad=True)
out = a * a
out.backward(a)
print(a.grad)
输出是
tensor([[ 2., 8., 18.],
[32., 50., 72.]])
也许是2*a*a
但我认为输出应该是
tensor([[ 2., 4., 6.],
[8., 10., 12.]])
2*a.
导致d(x^2)/dx=2x
【问题讨论】:
【参考方案1】:请仔细阅读backward()
上的文档以更好地理解它。
默认情况下,pytorch 期望为网络的 last 输出调用 backward()
- 损失函数。损失函数总是输出一个标量,因此 scalar 损失与所有其他变量/参数的梯度是明确定义的(使用链式法则)。
因此,默认情况下,backward()
在标量张量上调用并且不需要任何参数。
例如:
a = torch.tensor([[1,2,3],[4,5,6]], dtype=torch.float, requires_grad=True)
for i in range(2):
for j in range(3):
out = a[i,j] * a[i,j]
out.backward()
print(a.grad)
产量
tensor([[ 2., 4., 6.], [ 8., 10., 12.]])
如预期:d(a^2)/da = 2a
。
但是,当您在 2×3 out
张量(不再是标量函数)上调用 backward
时,您期望 a.grad
是什么?你实际上需要一个 2×3×2×3 的输出:d out[i,j] / d a[k,l]
(!)
Pytorch 不支持这种非标量函数导数。相反,pytorch 假设 out
只是一个中间张量,并且在“上游”某处有一个标量损失函数,通过链式规则提供 d loss/ d out[i,j]
。这个“上游”渐变的大小是 2×3,在这种情况下,这实际上是您提供的 backward
参数:out.backward(g)
where g_ij = d loss/ d out_ij
。
然后通过链式法则d loss / d a[i,j] = (d loss/d out[i,j]) * (d out[i,j] / d a[i,j])
计算梯度
因为您提供了 a
作为“上游”渐变,所以您得到了
a.grad[i,j] = 2 * a[i,j] * a[i,j]
如果您要提供“上游”渐变为全1
out.backward(torch.ones(2,3))
print(a.grad)
产量
tensor([[ 2., 4., 6.], [ 8., 10., 12.]])
正如预期的那样。
这一切都在链式法则中。
【讨论】:
我意识到这是一篇相对较旧的帖子,但换句话说,它正在计算向量雅可比积 a * (2 * a) 对吗? 当您说梯度是通过链式法则计算时,您忘记了矩阵乘法。所以不只是 (d loss/d out[i,j]) * (d out[i,j] / da[i,j]),实际上是 sum_k,l (d loss/d out[ k,l]) * (d out[k,l] / da[i,j]).以上是关于PyTorch 中的后向函数的主要内容,如果未能解决你的问题,请参考以下文章
如何在 PyTorch 中使用 autograd.gradcheck?