Tensorflow gradientTape 解释
Posted
技术标签:
【中文标题】Tensorflow gradientTape 解释【英文标题】:Tensorflow gradientTape explanation 【发布时间】:2019-10-04 01:36:38 【问题描述】:我正在尝试从 tensorflow tf.gradientTape 了解 API
下面是我从官网得到的代码:
x = tf.constant(3.0)
with tf.GradientTape(persistent=True) as g:
g.watch(x)
y = x * x
z = y * y
dz_dx = g.gradient(z, x) # 108.0 (4*x^3 at x = 3)
dy_dx = g.gradient(y, x) # 6.0
我想知道他们是如何将 dz_dx 设为 108 将 dy_dx 设为 6 的?
我还做了如下测试:
x = tf.constant(3.0)
with tf.GradientTape(persistent=True) as g:
g.watch(x)
y = x * x * x
z = y * y
dz_dx = g.gradient(z, x) # 1458.0
dy_dx = g.gradient(y, x) # 6.0
这次 dz_dx 变成了 1458,我完全不知道为什么。任何专家可以告诉我如何进行计算吗?
【问题讨论】:
【参考方案1】:从y=x*x
,我们可以得到dy/dx=2*x
。从z=y*y
,我们有dz/dy=2*y
。根据链式法则,dz/dx=(dz/dy)*(dy/dx)=(2*y)*(2*x)=(2*x*x)*(2*x)=108
。 dy/dx=2*x=6
。您的第二个示例的推导相同。顺便说一句,在您的第二个示例中,dy/dx
应该是 27 而不是 6。
【讨论】:
以上是关于Tensorflow gradientTape 解释的主要内容,如果未能解决你的问题,请参考以下文章
Tensorflow 强化学习 RNN 在使用 GradientTape 优化后返回 NaN
[TensorFlow系列-20]:TensorFlow基础 - Varialbe对象的手工求导和半自动链式求导tf.GradientTape
如何使用 tf.GradientTape 模拟 ReLU 梯度