从 tf.gradients() 到 tf.GradientTape() 的转换返回 None

Posted

技术标签:

【中文标题】从 tf.gradients() 到 tf.GradientTape() 的转换返回 None【英文标题】:Conversion from tf.gradients() to tf.GradientTape() returns None 【发布时间】:2021-05-27 05:08:35 【问题描述】:

我正在将一些 TF1 代码迁移到 TF2。有关完整代码,您可以查看here 行 [155-176]。 TF1 中有一条线在给定损失(浮点值)和 (m, n) 张量的情况下获得梯度

编辑:问题仍然存在

注意: TF2 代码应该兼容并且应该在 tf.function 中工作

g = tf.gradients(-loss, f)  # loss being a float and f being a (m, n) tensor
k = -f_pol / (f + eps)  # f_pol another (m, n) tensor and eps a float
k_dot_g = tf.reduce_sum(k * g, axis=-1)
adj = tf.maximum(
    0.0,
    (tf.reduce_sum(k * g, axis=-1) - delta)
    / (tf.reduce_sum(tf.square(k), axis=-1) + eps),
)
g = g - tf.reshape(adj, [nenvs * nsteps, 1]) * k
grads_f = -g / (nenvs * nsteps)
grads_policy = tf.gradients(f, params, grads_f)  # params being the model parameters

在我正在尝试的 TF2 代码中:

with tf.GradientTape() as tape:
    f = calculate_f()
    f_pol = calculate_f_pol()
    others = do_further_calculations()
    loss = calculate_loss()
g = tape.gradient(-loss, f)

但是,无论我使用 tape.watch(f) 还是创建具有 f 值的 tf.Variable 或什至在 tf.gradients() 中使用 tf.gradients(),我都会不断收到 g = [None],否则它会抱怨。

【问题讨论】:

-loss 是对张量的操作,它应该在磁带上下文中,以便跟踪反向传播。试试loss = -calculate_loss() 然后g = tape.gradient(loss, f),或者如果你更喜欢loss = calculate_loss(); nloss = -loss 然后g = tape.gradient(nloss, f) 你能用一些随机数据添加一个最小的例子吗? :) @Roelant 我已经修复了错误,我怀疑损失计算的某些方面发生在触发错误的tf.GradientTape 上下文之外。 【参考方案1】:

很可能是以下情况之一

    在由@tf.funtion 修饰的函数中定义tf.Variable ? 有些变量是 numpy.array 而不是 tf.Tensor 您更改了修饰函数内部的一些外部变量(即全局变量)。

【讨论】:

tf.function 中定义tf.Variable 通常会引发错误。通过在输入上调用模型两次(这没有多大意义),我能够摆脱这个问题。它可能是2.3。很可能3. 这是一个非常微妙的问题,因为这就像第 10 次从头开始重写代码,我仍然不知道出了什么问题。

以上是关于从 tf.gradients() 到 tf.GradientTape() 的转换返回 None的主要内容,如果未能解决你的问题,请参考以下文章

TensorFlow梯度求解tf.gradients

tf.gradients() 是如何工作的?

TensorFlow tf.gradients的用法详细解析以及具体例子

tensorflow-底层梯度tf.AggregationMethod,tf.gradients

如何使用`tf.gradients`? `TypeError: Fetch argument None has invalid type <type 'NoneType'>`

使用 tf.gradients 和 tf.hessian 时出现 TensorFlow 错误:TypeError: Fetch argument None has invalid type <t