tf.shape() 在张量流中得到错误的形状

Posted

技术标签:

【中文标题】tf.shape() 在张量流中得到错误的形状【英文标题】:tf.shape() get wrong shape in tensorflow 【发布时间】:2016-09-02 07:08:21 【问题描述】:

我这样定义一个张量:

x = tf.get_variable("x", [100])

但是当我尝试打印张量的形状时:

print( tf.shape(x) )

我得到Tensor("Shape:0", shape=(1,), dtype=int32),为什么输出的结果不应该是shape=(100)

【问题讨论】:

我发现this 的答案对于分析张量的形状非常有用,尽管它不是公认的。 【参考方案1】:

tf.shape(input, name=None) 返回一个表示输入形状的一维整数张量。

您正在寻找:x.get_shape(),它返回 x 变量的 TensorShape

更新:因为这个答案,我写了一篇文章来阐明 Tensorflow 中的动态/静态形状:https://pgaleone.eu/tensorflow/2018/07/28/understanding-tensorflow-tensors-shape-static-dynamic/

【讨论】:

x.get_shape().as_list() 是一种常用的形式,用于将形状转换为标准的 python 列表。在此添加以供参考。【参考方案2】:

澄清:

tf.shape(x) 创建一个操作并返回一个对象,该对象代表构造操作的输出,这是您当前正在打印的内容。要获取形状,请在会话中运行操作:

matA = tf.constant([[7, 8], [9, 10]])
shapeOp = tf.shape(matA) 
print(shapeOp) #Tensor("Shape:0", shape=(2,), dtype=int32)
with tf.Session() as sess:
   print(sess.run(shapeOp)) #[2 2]

credit:看了上面的答案后,我看到了tf.rank function in Tensorflow 的答案,我发现它更有帮助,我在这里尝试重新措辞。

【讨论】:

【参考方案3】:

只是一个简单的例子,让事情更清楚:

a = tf.Variable(tf.zeros(shape=(2, 3, 4)))
print('-'*60)
print("v1", tf.shape(a))
print('-'*60)
print("v2", a.get_shape())
print('-'*60)
with tf.Session() as sess:
    print("v3", sess.run(tf.shape(a)))
print('-'*60)
print("v4",a.shape)

输出将是:

------------------------------------------------------------
v1 Tensor("Shape:0", shape=(3,), dtype=int32)
------------------------------------------------------------
v2 (2, 3, 4)
------------------------------------------------------------
v3 [2 3 4]
------------------------------------------------------------
v4 (2, 3, 4)

这也应该有帮助: How to understand static shape and dynamic shape in TensorFlow?

【讨论】:

【参考方案4】:

TF FAQ 很好地解释了类似的问题:

在 TensorFlow 中,张量同时具有静态(推断)形状和 动态(真实)形状。静态形状可以使用 tf.Tensor.get_shape 方法:这个形状是从操作中推断出来的 用于创建张量的,可能是部分完整的。如果 静态形状未完全定义,张量 t 的动态形状 可以通过评估tf.shape(t)来确定。

所以tf.shape() 会返回一个张量,其大小始终为shape=(N,),并且可以在会话中计算:

a = tf.Variable(tf.zeros(shape=(2, 3, 4)))
with tf.Session() as sess:
    print sess.run(tf.shape(a))

另一方面,您可以使用x.get_shape().as_list() 提取静态形状,这可以在任何地方计算。

【讨论】:

shape=(N,) 代表什么?当静态形状和动态形状不同时,你能举例说明吗? @mrgloom shape=(n,) 表示大小为 n 的向量。展示这样的例子并不容易,因为你需要将 TF 混淆到足以失去对形状的控制【参考方案5】:

简单地说,使用tensor.shape 得到静态形状

In [102]: a = tf.placeholder(tf.float32, [None, 128])

# returns [None, 128]
In [103]: a.shape.as_list()
Out[103]: [None, 128]

而要获得动态形状,请使用tf.shape()

dynamic_shape = tf.shape(a)

您也可以像在 NumPy 中一样使用your_tensor.shape 获取形状,如下例所示。

In [11]: tensr = tf.constant([[1, 2, 3, 4, 5], [2, 3, 4, 5, 6]])

In [12]: tensr.shape
Out[12]: TensorShape([Dimension(2), Dimension(5)])

In [13]: list(tensr.shape)
Out[13]: [Dimension(2), Dimension(5)]

In [16]: print(tensr.shape)
(2, 5)

另外,这个例子中的张量可以是evaluated。

In [33]: tf.shape(tensr).eval().tolist()
Out[33]: [2, 5]

【讨论】:

【参考方案6】:

Tensorflow 2.0 兼容答案Tensorflow 2.x (>= 2.0) nessuno 解决方案的兼容答案如下所示:

x = tf.compat.v1.get_variable("x", [100])

print(x.get_shape())

【讨论】:

以上是关于tf.shape() 在张量流中得到错误的形状的主要内容,如果未能解决你的问题,请参考以下文章

如何在张量流中使用索引数组?

如何在张量流中张量的某些索引处插入某些值?

为啥我会得到不同形状的张量错误?

张量流中的条件图和访问张量大小的for循环

根据张量流中给定的序列长度数组对 3D 张量进行切片

张量流中推理时的批量标准化