tensorflow2.0 squeeze出错
Posted lolybj
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了tensorflow2.0 squeeze出错相关的知识,希望对你有一定的参考价值。
用tf.keras写了自定义层,但在调用自定义层的时候总是报错,找了好久才发现问题所在,所以记下此问题。
问题代码
u=tf.squeeze(tf.expand_dims(tf.expand_dims(inputs,axis=1),axis=3)@self.kernel,axis=3)
其中inputs的第一维为None,这里的代码为自定义的前向传播。我是想将得到的输出张量维度为1的维度删掉,因此调用了tf.squeeze方法,这时虽然没有报错但出现了问题。我分别打印了下面内容。
print(tf.expand_dims(tf.expand_dims(inputs,axis=1),axis=3).shape)
print(self.kernel.shape)
print((tf.expand_dims(tf.expand_dims(inputs,axis=1),axis=3)@self.kernel).shape)
print(tf.squeeze(tf.expand_dims(tf.expand_dims(inputs,axis=1))@self.kernel,axis=3))
可以发现,当张量第一维为None的时候tf.squeeze使结果变为了0。我想要的结果是删除第三个输出的大小为1的维度,即得到下面的结果
解决使用tf.squeeze的时候加上删除的维度。
tf.squeeze(tf.expand_dims(tf.expand_dims(inputs,axis=1),axis=3)@self.kernel,axis=3)
以上是关于tensorflow2.0 squeeze出错的主要内容,如果未能解决你的问题,请参考以下文章
在 Tensorflow 2.0 中使用 GradientTape() 和 jacobian() 时出错
《30天吃掉那只 TensorFlow2.0》 2-3 自动微分机制