Tensorflow compute_output_shape()不适用于自定义层

Posted

技术标签:

【中文标题】Tensorflow compute_output_shape()不适用于自定义层【英文标题】:Tensorflow compute_output_shape() Not Working For Custom Layer 【发布时间】:2018-12-04 08:19:59 【问题描述】:

我在 Keras 中创建了一个自定义层(称为 GraphGather),但输出张量打印为:

Tensor("graph_gather/Tanh:0", shape=(?, ?), dtype=float32)

由于某种原因,形状被返回为 (?,?),这导致下一个密集层引发以下错误:

ValueError:应定义Dense 输入的最后一个维度。找到None

GraphGather层代码如下:

class GraphGather(tf.keras.layers.Layer):

  def __init__(self, batch_size, num_mols_in_batch, activation_fn=None, **kwargs):
    self.batch_size = batch_size
    self.num_mols_in_batch = num_mols_in_batch
    self.activation_fn = activation_fn
    super(GraphGather, self).__init__(**kwargs)

  def build(self, input_shape):
    super(GraphGather, self).build(input_shape)

 def call(self, x, **kwargs):
    # some operations (most of def call omitted)
    out_tensor = result_of_operations() # this line is pseudo code
    if self.activation_fn is not None:
      out_tensor = self.activation_fn(out_tensor)
    out_tensor = out_tensor
    return out_tensor

  def compute_output_shape(self, input_shape):
    return (self.num_mols_in_batch, 2 * input_shape[0][-1])

I have also tried hardcoding compute_output_shape to be: python 定义计算输出形状(自我,输入形状): 返回 (64, 150) ``` 然而打印时的输出张量仍然是

Tensor("graph_gather/Tanh:0", shape=(?, ?), dtype=float32)

这会导致上面写的 ValueError 。


系统信息

已编写自定义代码 **操作系统平台和发行版*:Linux Ubuntu 16.04 TensorFlow 版本(使用下面的命令):1.5.0 Python 版本:3.5.5

【问题讨论】:

也许它期望下一层有相同数量的示例,问题出在批量大小组件的形状上? 【参考方案1】:

我遇到了同样的问题。我的解决方法是在调用方法中添加以下几行:

input_shape = tf.shape(x)

然后:

return tf.reshape(out_tensor, self.compute_output_shape(input_shape))

我还没有遇到任何问题。

【讨论】:

【参考方案2】:

如果 Johnny 的回答不起作用,我发现另一种解决方法是遵循此处的建议 https://github.com/tensorflow/tensorflow/issues/38296#issuecomment-623698709

这是在你的层的输出上调用set_shape方法。

例如

l=GraphGather(...)
y=l(x)
y.set_shape( l.compute_output_shape(x.shape) )

这仅在您使用函数式 API 时有效。

【讨论】:

以上是关于Tensorflow compute_output_shape()不适用于自定义层的主要内容,如果未能解决你的问题,请参考以下文章

如何让 Tensorflow Profiler 在 Tensorflow 2.5 中使用“tensorflow-macos”和“tensorflow-metal”工作

python [test tensorflow] test tensorflow installation #tensorflow

关于tensorflow的显存占用问题

java调用tensorflow训练好的模型

tensorflow新手必看,tensorflow入门教程,tensorflow示例代码

tensorflow 如何在线训练模型