使用 [:, :, 0] 的 TF2 / Keras 切片张量

Posted

技术标签:

【中文标题】使用 [:, :, 0] 的 TF2 / Keras 切片张量【英文标题】:TF2 / Keras slice tensor using [:, :, 0] 【发布时间】:2019-12-17 17:00:18 【问题描述】:

在 TF 2.0 Beta 中我正在尝试:

x = tf.keras.layers.Input(shape=(240, 2), dtype=tf.float32)
print(x.shape) # (None, 240, 2)
a = x[:, :, 0]
print(a.shape) # <unknown>

在 TF 1.x 中我可以做到:

x = tf1.placeholder(tf1.float32, (None, 240, 2)
a = x[:, :, 0]

它会正常工作。如何在 TF 2.0 中实现这一点?我觉得

tf.split(x, 2, axis=2)

可能有效,但是我想使用切片而不是硬编码 2(轴 2 的暗淡)。

【问题讨论】:

【参考方案1】:

不同之处在于Input 返回的对象代表一个层,而不是任何类似于占位符或张量的对象。所以上面 tf 2.0 代码中的 x 是一个图层对象,而 tf 1.x 代码中的 x 是张量的占位符。

您可以定义一个切片层来执行该操作。有现成可用的层,但对于像这样的简单切片,Lambda 层非常易于阅读,并且可能最接近您在 tf 1.x 中习惯的切片方式。

类似这样的:

input_lyr = tf.keras.layers.Input(shape=(240, 2), dtype=tf.float32)
sliced_lyr = tf.keras.layers.Lambda(lambda x: x[:,:,0])

你可以像这样在你的 keras 模型中使用它:

model = tf.keras.models.Sequential([
    input_lyr,
    sliced_lyr,
    # ...
    # <other layers>
    # ...
])

当然,以上是特定于 keras 模型的。相反,如果你有一个张量而不是一个 keras 层对象,那么切片就像以前一样工作。像这样的:

my_tensor = tf.random.uniform((8,240,2))
sliced = my_tensor[:,:,0]

print(my_tensor.shape)
print(sliced.shape)

输出:

(8, 240, 2)
(8, 240)

如预期的那样

【讨论】:

以上是关于使用 [:, :, 0] 的 TF2 / Keras 切片张量的主要内容,如果未能解决你的问题,请参考以下文章

使用 TF2.0 训练 RNN 的每次迭代逐渐增加内存使用量

使用 dropout (TF2.0) 时,可变批量大小不适用于 tf.keras.layers.RNN?

『TensorFlow2.0正式版教程』极简安装TF2.0正式版(CPU&GPU)教程

拥抱TF2.0的时代来了

用TF2.0 复现经典推荐系统论文

[转帖]谷歌TF2.0凌晨发布!“改变一切,力压PyTorch”