Tensorflow gradientTape 解释

Posted

技术标签:

【中文标题】Tensorflow gradientTape 解释【英文标题】:Tensorflow gradientTape explanation 【发布时间】:2019-10-04 01:36:38 【问题描述】:

我正在尝试从 tensorflow tf.gradientTape 了解 API

下面是我从官网得到的代码:

x = tf.constant(3.0)
with tf.GradientTape(persistent=True) as g:
  g.watch(x)
  y = x * x
  z = y * y
dz_dx = g.gradient(z, x)  # 108.0 (4*x^3 at x = 3)
dy_dx = g.gradient(y, x)  # 6.0

我想知道他们是如何将 dz_dx 设为 108 将 dy_dx 设为 6 的?

我还做了如下测试:

x = tf.constant(3.0)
with tf.GradientTape(persistent=True) as g:
  g.watch(x)
  y = x * x * x
  z = y * y
dz_dx = g.gradient(z, x)  # 1458.0 
dy_dx = g.gradient(y, x)  # 6.0

这次 dz_dx 变成了 1458,我完全不知道为什么。任何专家可以告诉我如何进行计算吗?

【问题讨论】:

【参考方案1】:

y=x*x,我们可以得到dy/dx=2*x。从z=y*y,我们有dz/dy=2*y。根据链式法则,dz/dx=(dz/dy)*(dy/dx)=(2*y)*(2*x)=(2*x*x)*(2*x)=108dy/dx=2*x=6。您的第二个示例的推导相同。顺便说一句,在您的第二个示例中,dy/dx 应该是 27 而不是 6。

【讨论】:

以上是关于Tensorflow gradientTape 解释的主要内容,如果未能解决你的问题,请参考以下文章

Tensorflow 强化学习 RNN 在使用 GradientTape 优化后返回 NaN

[TensorFlow系列-20]:TensorFlow基础 - Varialbe对象的手工求导和半自动链式求导tf.GradientTape

tf.GradientTape详解

如何使用 tf.GradientTape 模拟 ReLU 梯度

Tf 2.0 : RuntimeError: GradientTape.gradient 只能在非持久性磁带上调用一次

Tensorflow 2.0的自定义训练循环的学习率