建议在 tensorflow 2.0 中调试 `tf.data.Dataset` 操作

Posted

技术标签:

【中文标题】建议在 tensorflow 2.0 中调试 `tf.data.Dataset` 操作【英文标题】:advise for debugging `tf.data.Dataset` operations in tensorflow 2.0 【发布时间】:2019-10-03 01:22:25 【问题描述】:

对于 tf 数据集,Panda 的 df.head() 相当于什么?

按照文档here,我构建了以下玩具示例:

dset = tf.data.Dataset.from_tensor_slices((tf.constant([1.,2.,3.]), tf.constant([4.,4.,4.]), tf.constant([5.,6.,7.])))
print(dset)

输出

<TensorSliceDataset shapes: ((), (), ()), types: (tf.float32, tf.float32, tf.float32)>

我更愿意取回类似于张量的东西,所以为了获得一些值,我将创建一个迭代器。

dset_iter = dset.__iter__()
print(dset_iter.next())

输出

(<tf.Tensor: id=122, shape=(), dtype=float32, numpy=1.0>,
 <tf.Tensor: id=123, shape=(), dtype=float32, numpy=4.0>,
 <tf.Tensor: id=124, shape=(), dtype=float32, numpy=5.0>)

到目前为止一切顺利。让我们尝试一些窗口化...

windowed = dset.window(2)
print(windowed)

输出

<WindowDataset shapes: (<tensorflow.python.data.ops.dataset_ops.DatasetStructure object at 0x1349b25c0>, <tensorflow.python.data.ops.dataset_ops.DatasetStructure object at 0x1349b27b8>, <tensorflow.python.data.ops.dataset_ops.DatasetStructure object at 0x1349b29b0>), types: (<tensorflow.python.data.ops.dataset_ops.DatasetStructure object at 0x1349b25c0>, <tensorflow.python.data.ops.dataset_ops.DatasetStructure object at 0x1349b27b8>, <tensorflow.python.data.ops.dataset_ops.DatasetStructure object at 0x1349b29b0>)>

好的,再次使用迭代器技巧:

windowed_iter = windowed.__iter__()
windowed_iter.next()

输出

(<_VariantDataset shapes: (), types: tf.float32>,
 <_VariantDataset shapes: (), types: tf.float32>,
 <_VariantDataset shapes: (), types: tf.float32>)

什么? WindowDataset 的迭代器返回一个 tuple 其他数据集对象? 我希望这个 WindowDataset 中的第一项是值为 [[1.,4.,5.],[2.,4.,6.]] 的张量。也许这仍然是正确的,但从这个 3 元组数据集中对我来说并不是很明显。 好的。让我们他们的迭代器...

vd = windowed_iter.get_next()
vd0, vd1, vd2 = vd[0], vd[1], vd[2]
vd0i, vd1i, vd2i = vd0.__iter__(), vd1.__iter__(), vd2.__iter__()
print(vd0i.next(), vd1i.next(), vd2i.next())

输出

(<tf.Tensor: id=357, shape=(), dtype=float32, numpy=1.0>,
 <tf.Tensor: id=358, shape=(), dtype=float32, numpy=4.0>,
 <tf.Tensor: id=359, shape=(), dtype=float32, numpy=5.0>)

如您所见,这个工作流程很快变得一团糟。我喜欢 Tf2.0 如何尝试使框架更具交互性和 Pythonic。是否也有符合这一愿景的数据集 api 的好例子?

【问题讨论】:

【参考方案1】:

我也遇到过类似的情况。我最终使用了 zip

train_dataset = train_dataset.window(10, shift=5)
for step_dataset in train_dataset:
    for (images, labels, paths) in zip(*step_dataset):
        train_step(images, labels)

【讨论】:

以上是关于建议在 tensorflow 2.0 中调试 `tf.data.Dataset` 操作的主要内容,如果未能解决你的问题,请参考以下文章

警告:tensorflow:`write_grads` 将在 TensorFlow 2.0 中忽略`TensorBoard` 回调

TensorFlow 2.0 在单 GPU 上训练模型

Keras 2.3.0 发布:支持TensorFlow 2.0!!!!!

初学者的 TensorFlow 2.0 教程

Tensorflow 2.0 最新版(2.4.1) 安装教程

Tensorflow 2.0 最新版(2.4.1) 安装教程