Tensorflow:连接多个tf.Dataset非常慢

Posted

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Tensorflow:连接多个tf.Dataset非常慢相关的知识,希望对你有一定的参考价值。

我在Tensorflow 1.10上

现在我不确定这是不是一个bug。

我一直试图连接我从多个tf.data.Dataset.from_generator生成的100个数据集。

for i in range(1, 100):
        dataset = dataset.concatenate(
            tf.data.Dataset.from_generator(gens[i], (tf.int8, tf.int32), output_shapes=(
                (256, 256), (1))))
        print(i)
 print("before iterator")
 iterator = dataset.make_one_shot_iterator()
 print("after iterator")

运行make_one_shot_iterator()需要很长时间。

有人知道修复吗?

编辑:

看起来_make_dataset.add_to_graph(ops.get_default_graph())似乎被一遍又一遍地调用,导致几百万次调用该函数。 (https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/data/ops/dataset_ops.py函数make_one_shot_iterator第162行)

答案

对于像这样的多个张量或生成器来说,运行concatenateis实际上并不是最好的选择。

更好的方法是使用flat_map https://www.tensorflow.org/api_docs/python/tf/data/Dataset#flat_map。我确实更新了示例a,然后展示了如何将它用于多个张量或文件。

以上是关于Tensorflow:连接多个tf.Dataset非常慢的主要内容,如果未能解决你的问题,请参考以下文章

为啥在复制 tf.dataset 时使用 steps_per_epoch?

TensorFlow 数据集 API:缓存

使用tf.train时,使用tf.dataset的Keras model.fit()会失败

利用TF dataset改善模型训练效率的最佳实践

利用TF dataset改善模型训练效率的最佳实践

利用TF dataset改善模型训练效率的最佳实践