使用 [:, :, 0] 的 TF2 / Keras 切片张量
Posted
技术标签:
【中文标题】使用 [:, :, 0] 的 TF2 / Keras 切片张量【英文标题】:TF2 / Keras slice tensor using [:, :, 0] 【发布时间】:2019-12-17 17:00:18 【问题描述】:在 TF 2.0 Beta 中我正在尝试:
x = tf.keras.layers.Input(shape=(240, 2), dtype=tf.float32)
print(x.shape) # (None, 240, 2)
a = x[:, :, 0]
print(a.shape) # <unknown>
在 TF 1.x 中我可以做到:
x = tf1.placeholder(tf1.float32, (None, 240, 2)
a = x[:, :, 0]
它会正常工作。如何在 TF 2.0 中实现这一点?我觉得
tf.split(x, 2, axis=2)
可能有效,但是我想使用切片而不是硬编码 2(轴 2 的暗淡)。
【问题讨论】:
【参考方案1】:不同之处在于Input
返回的对象代表一个层,而不是任何类似于占位符或张量的对象。所以上面 tf 2.0 代码中的 x
是一个图层对象,而 tf 1.x 代码中的 x
是张量的占位符。
您可以定义一个切片层来执行该操作。有现成可用的层,但对于像这样的简单切片,Lambda
层非常易于阅读,并且可能最接近您在 tf 1.x 中习惯的切片方式。
类似这样的:
input_lyr = tf.keras.layers.Input(shape=(240, 2), dtype=tf.float32)
sliced_lyr = tf.keras.layers.Lambda(lambda x: x[:,:,0])
你可以像这样在你的 keras 模型中使用它:
model = tf.keras.models.Sequential([
input_lyr,
sliced_lyr,
# ...
# <other layers>
# ...
])
当然,以上是特定于 keras 模型的。相反,如果你有一个张量而不是一个 keras 层对象,那么切片就像以前一样工作。像这样的:
my_tensor = tf.random.uniform((8,240,2))
sliced = my_tensor[:,:,0]
print(my_tensor.shape)
print(sliced.shape)
输出:
(8, 240, 2)
(8, 240)
如预期的那样
【讨论】:
以上是关于使用 [:, :, 0] 的 TF2 / Keras 切片张量的主要内容,如果未能解决你的问题,请参考以下文章
使用 TF2.0 训练 RNN 的每次迭代逐渐增加内存使用量
使用 dropout (TF2.0) 时,可变批量大小不适用于 tf.keras.layers.RNN?