Tensorflow compute_output_shape()不适用于自定义层
Posted
技术标签:
【中文标题】Tensorflow compute_output_shape()不适用于自定义层【英文标题】:Tensorflow compute_output_shape() Not Working For Custom Layer 【发布时间】:2018-12-04 08:19:59 【问题描述】:我在 Keras 中创建了一个自定义层(称为 GraphGather),但输出张量打印为:
Tensor("graph_gather/Tanh:0", shape=(?, ?), dtype=float32)
由于某种原因,形状被返回为 (?,?),这导致下一个密集层引发以下错误:
ValueError:应定义
Dense
输入的最后一个维度。找到None
。
GraphGather层代码如下:
class GraphGather(tf.keras.layers.Layer):
def __init__(self, batch_size, num_mols_in_batch, activation_fn=None, **kwargs):
self.batch_size = batch_size
self.num_mols_in_batch = num_mols_in_batch
self.activation_fn = activation_fn
super(GraphGather, self).__init__(**kwargs)
def build(self, input_shape):
super(GraphGather, self).build(input_shape)
def call(self, x, **kwargs):
# some operations (most of def call omitted)
out_tensor = result_of_operations() # this line is pseudo code
if self.activation_fn is not None:
out_tensor = self.activation_fn(out_tensor)
out_tensor = out_tensor
return out_tensor
def compute_output_shape(self, input_shape):
return (self.num_mols_in_batch, 2 * input_shape[0][-1])
I have also tried hardcoding compute_output_shape to be:
python
定义计算输出形状(自我,输入形状):
返回 (64, 150)
```
然而打印时的输出张量仍然是
Tensor("graph_gather/Tanh:0", shape=(?, ?), dtype=float32)
这会导致上面写的 ValueError 。
系统信息
已编写自定义代码 **操作系统平台和发行版*:Linux Ubuntu 16.04 TensorFlow 版本(使用下面的命令):1.5.0 Python 版本:3.5.5【问题讨论】:
也许它期望下一层有相同数量的示例,问题出在批量大小组件的形状上? 【参考方案1】:我遇到了同样的问题。我的解决方法是在调用方法中添加以下几行:
input_shape = tf.shape(x)
然后:
return tf.reshape(out_tensor, self.compute_output_shape(input_shape))
我还没有遇到任何问题。
【讨论】:
【参考方案2】:如果 Johnny 的回答不起作用,我发现另一种解决方法是遵循此处的建议 https://github.com/tensorflow/tensorflow/issues/38296#issuecomment-623698709
这是在你的层的输出上调用set_shape
方法。
例如
l=GraphGather(...)
y=l(x)
y.set_shape( l.compute_output_shape(x.shape) )
这仅在您使用函数式 API 时有效。
【讨论】:
以上是关于Tensorflow compute_output_shape()不适用于自定义层的主要内容,如果未能解决你的问题,请参考以下文章
如何让 Tensorflow Profiler 在 Tensorflow 2.5 中使用“tensorflow-macos”和“tensorflow-metal”工作
python [test tensorflow] test tensorflow installation #tensorflow