tensorflow中gradients的使用以及TypeError: Fetch argument None has invalid type <class 'NoneType'

Posted yangzepeng

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了tensorflow中gradients的使用以及TypeError: Fetch argument None has invalid type <class 'NoneType'相关的知识,希望对你有一定的参考价值。

在反向传播过程中,神经网络需要对每一个loss对应的学习参数求偏导,算出的这个值也就是梯度,用来乘以学习率更新学习参数使用的,它是通过tensorflow中gradients函数使用的。

我们根据官方文档对函数原型进行解析

官方文档中函数原型以及参数如下:

tf.gradients(
    ys,
    xs,
    grad_ys=None,
    name=gradients,
    colocate_gradients_with_ops=False,
    gate_gradients=False,
    aggregation_method=None,
    stop_gradients=None,
    unconnected_gradients=tf.UnconnectedGradients.NONE
)

ys和xs都是张量或者张量列表。函数tf.gradients作用是在ys中对xs求导,求导的返回值是一个list,list的长度与xs的长度相同。

下面通过例子介绍函数的用法(这是李金洪老师那本书中举到的例子)

 

import tensorflow as tf
w1 = tf.Variable([[1,2]])
w2 = tf.Variable([[3,4]])

y = tf.matmul(w1, [[9],[10]])
grads = tf.gradients(y,[w1])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    gradval = sess.run(grads)
    print(gradval)

 

运行这段代码会报错,报错为:

TypeError: Fetch argument None has invalid type <class NoneType>

原因是Tensorflow gradients好像int型的Tensor 的gradients 把w1的设置成float类型的例如tf.float32 gards就能算了,而且tensorflow梯度值一般都是float32类型的。所以我们修改代码将整型的张量改为浮点型:

import tensorflow as tf
w1 = tf.Variable([[1.,2.]])
w2 = tf.Variable([[3.,4.]])

y = tf.matmul(w1, [[9.],[10.]])
grads = tf.gradients(y,[w1])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    gradval = sess.run(grads)
    print(gradval)

输出结果为:

[array([[ 9., 10.]], dtype=float32)]

上面例子中,由于y是由w1与[[9],[10]]相乘而来,所以其导数也就是[[9],[10]](即斜率)。

注意:如果求梯度的式子中没有要求偏导的变量,系统会报错。例如,写成grads = tf.gradients(y,[w1,w2])。

 

以上是关于tensorflow中gradients的使用以及TypeError: Fetch argument None has invalid type <class 'NoneType'的主要内容,如果未能解决你的问题,请参考以下文章

TensorFlow tf.gradients的用法详细解析以及具体例子

tensorflow Optimizer.minimize()和gradient clipping

TensorFLow: Gradient Clipping

优化器类中 tensorflow 最小化()函数中的“gate_gradients”属性是啥?

什么时候在 Tensorflow Gradient Tape 中应用 Momentum?

tensorflow-底层梯度tf.AggregationMethod,tf.gradients