『PyTorch』第五弹_深入理解autograd_下:Variable梯度探究

Posted 叠加态的猫

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了『PyTorch』第五弹_深入理解autograd_下:Variable梯度探究相关的知识,希望对你有一定的参考价值。

查看非叶节点梯度的两种方法

在反向传播过程中非叶子节点的导数计算完之后即被清空。若想查看这些变量的梯度,有两种方法:

  • 使用autograd.grad函数
  • 使用hook

autograd.gradhook方法都是很强大的工具,更详细的用法参考官方api文档,这里举例说明基础的使用。推荐使用hook方法,但是在实际使用中应尽量避免修改grad的值。

求z对y的导数

x = V(t.ones(3))
w = V(t.rand(3),requires_grad=True)
y = w.mul(x)
z = y.sum()

# hook
# hook没有返回值,参数是函数,函数的参数是梯度值
def variable_hook(grad):
    print("hook梯度输出:\r\n",grad)

hook_handle = y.register_hook(variable_hook)         # 注册hook
z.backward(retain_graph=True)                        # 内置输出上面的hook
hook_handle.remove()                                 # 释放

print("autograd.grad输出:\r\n",t.autograd.grad(z,y)) # t.autograd.grad方法
hook梯度输出:
 Variable containing:
 1
 1
 1
[torch.FloatTensor of size 3]

autograd.grad输出:
 (Variable containing:
 1
 1
 1
[torch.FloatTensor of size 3]
,)

 

多次反向传播试验

实际就是使用retain_graph参数,

# 构件图
x = V(t.ones(3))
w = V(t.rand(3),requires_grad=True)
y = w.mul(x)
z = y.sum()

z.backward(retain_graph=True)
print(w.grad)
z.backward()
print(w.grad)
Variable containing:
 1
 1
 1
[torch.FloatTensor of size 3]

Variable containing:
 2
 2
 2
[torch.FloatTensor of size 3]

 

如果不使用retain_graph参数,

  • 实际上效果是一样的,AccumulateGrad object仍然会积累梯度
  • 除了叶子节点之外,高层节点需要重新定义,因为原图已经传播了,需要基于原叶子建立新图,实际上第二次的z.backward()已经不是第一次的z所在的图了,这里看似简单,实际上体现了动态图的技术,静态图初始化之后会留在内存中等待feed数据,但是动态图不会,反向传播后就已经被废弃,下次要么完全重建(如下),要么反向传播之后指定不舍弃图z.backward(retain_graph=True),总之和常规的数据结构不同,图上的节点是隶属于图的属性的,TensorFlow中会一直存留,PyTorch中就会backward后直接舍弃(默认时)。
# 构件图
x = V(t.ones(3))
w = V(t.rand(3),requires_grad=True)
y = w.mul(x)
z = y.sum()

z.backward()
print(w.grad)
y = w.mul(x)  # <-----
z = y.sum()  # <-----
z.backward()
print(w.grad)
Variable containing:
 1
 1
 1
[torch.FloatTensor of size 3]

Variable containing:
 2
 2
 2
[torch.FloatTensor of size 3]

 

以上是关于『PyTorch』第五弹_深入理解autograd_下:Variable梯度探究的主要内容,如果未能解决你的问题,请参考以下文章

『PyTorch』第五弹_深入理解autograd_下:Variable梯度探究

『PyTorch』第五弹_深入理解Tensor对象_中上:索引

『PyTorch』第五弹_深入理解Tensor对象_上:初始化以及尺寸调整

『PyTorch』第五弹_深入理解Tensor对象_下:从内存看Tensor

『PyTorch』第五弹_深入理解Tensor对象_中下:数学计算以及numpy比较

C语言期末第五弹