如何在 TensorFlow 2.0 中使用 Dataset.window() 方法创建的窗口?

Posted

技术标签:

【中文标题】如何在 TensorFlow 2.0 中使用 Dataset.window() 方法创建的窗口?【英文标题】:How to use windows created by the Dataset.window() method in TensorFlow 2.0? 【发布时间】:2019-08-21 02:17:54 【问题描述】:

我正在尝试使用 TensorFlow 2.0 创建一个数据集,该数据集将返回时间序列中的随机窗口以及作为目标的下一个值。

我正在使用Dataset.window(),看起来很有希望:

import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices(tf.range(10))
dataset = dataset.window(5, shift=1, drop_remainder=True)
for window in dataset:
    print([elem.numpy() for elem in window])

输出:

[0, 1, 2, 3, 4]
[1, 2, 3, 4, 5]
[2, 3, 4, 5, 6]
[3, 4, 5, 6, 7]
[4, 5, 6, 7, 8]
[5, 6, 7, 8, 9]

但是,我想使用最后一个值作为目标。如果每个窗口都是张量,我会使用:

dataset = dataset.map(lambda window: (window[:-1], window[-1:]))

但是,如果我尝试这个,我会得到一个异常:

TypeError: '_VariantDataset' object is not subscriptable

【问题讨论】:

【参考方案1】:

解决办法是这样调用flat_map()

dataset = dataset.flat_map(lambda window: window.batch(5))

现在数据集中的每个项目都是一个窗口,所以你可以这样拆分它:

dataset = dataset.map(lambda window: (window[:-1], window[-1:]))

所以完整的代码是:

import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices(tf.range(10))
dataset = dataset.window(5, shift=1, drop_remainder=True)
dataset = dataset.flat_map(lambda window: window.batch(5))
dataset = dataset.map(lambda window: (window[:-1], window[-1:]))

for X, y in dataset:
    print("Input:", X.numpy(), "Target:", y.numpy())

哪些输出:

Input: [0 1 2 3] Target: [4]
Input: [1 2 3 4] Target: [5]
Input: [2 3 4 5] Target: [6]
Input: [3 4 5 6] Target: [7]
Input: [4 5 6 7] Target: [8]
Input: [5 6 7 8] Target: [9]

【讨论】:

即使没有必要回答这个问题,您能否详细说明为什么我们需要这个 flat_map 步骤?我仍然在努力理解它。 window() 方法返回一个包含窗口的数据集,其中每个窗口本身都表示为一个数据集。类似于 1,2,3,4,5,6,7,8,9,10,...,其中 ... 表示数据集。但是我们只需要一个包含张量的常规数据集:[1,2,3,4,5],[6,7,8,9,10],...,其中 [...] 表示一个张量。 flat_map() 方法在转换每个嵌套数据集后返回嵌套数据集中的所有张量。如果我们不进行批处理,我们会得到:1,2,3,4,5,6,7,8,9,10,...。通过将每个窗口批处理为其完整大小,我们得到了我们想要的 [1,2,3,4,5],[6,7,8,9,10],...。清除吗? 有没有办法对这些样品进行小批量生产?我们已经有一个来自 window.batch(5) 的 None 维度,因此在添加例如dataset.batch(3),然后我们得到另一个无维度 好吧,它确实有效,因为 window.batch 调用的 None 维度当然是必要的。 使用flat_map会失去len()的能力,但可以改用len(list(dataset))的慢版本

以上是关于如何在 TensorFlow 2.0 中使用 Dataset.window() 方法创建的窗口?的主要内容,如果未能解决你的问题,请参考以下文章

如何在 TensorFlow 2.0 中使用 Dataset.window() 方法创建的窗口?

如何在 Tensorflow 2.0 中使用 K.get_session 或如何迁移它?

如何在 Tensorflow 2.0 + Keras 中进行并行 GPU 推理?

如何在 tensorflow 2.0 w/keras 中保存/恢复大型模型?

如何在 tfds.load() 之后在 TensorFlow 2.0 中应用数据增强

如何保存使用Tensorflow 1.xx中的.meta检查点模型作为部分的Tensorflow 2.0模型?