tensorflow2.0 squeeze出错

Posted lolybj

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了tensorflow2.0 squeeze出错相关的知识,希望对你有一定的参考价值。

用tf.keras写了自定义层,但在调用自定义层的时候总是报错,找了好久才发现问题所在,所以记下此问题。

问题代码

u=tf.squeeze(tf.expand_dims(tf.expand_dims(inputs,axis=1),axis=3)@self.kernel,axis=3)

其中inputs的第一维为None,这里的代码为自定义的前向传播。我是想将得到的输出张量维度为1的维度删掉,因此调用了tf.squeeze方法,这时虽然没有报错但出现了问题。我分别打印了下面内容。

print(tf.expand_dims(tf.expand_dims(inputs,axis=1),axis=3).shape)
print(self.kernel.shape)
print((tf.expand_dims(tf.expand_dims(inputs,axis=1),axis=3)@self.kernel).shape)
print(tf.squeeze(tf.expand_dims(tf.expand_dims(inputs,axis=1))@self.kernel,axis=3))

技术图片

可以发现,当张量第一维为None的时候tf.squeeze使结果变为了0。我想要的结果是删除第三个输出的大小为1的维度,即得到下面的结果

技术图片

解决使用tf.squeeze的时候加上删除的维度。

tf.squeeze(tf.expand_dims(tf.expand_dims(inputs,axis=1),axis=3)@self.kernel,axis=3)

以上是关于tensorflow2.0 squeeze出错的主要内容,如果未能解决你的问题,请参考以下文章

在 Tensorflow 2.0 中使用 GradientTape() 和 jacobian() 时出错

《30天吃掉那只 TensorFlow2.0》 2-3 自动微分机制

TensorFlow2.0--TensorFlow2.0构架

请注意更新TensorFlow 2.0的旧代码

Tensorflow2.0笔记

Tensorflow2.0笔记