提供给 `tf.data.Dataset.from_generator(...)` 的 map 函数可以解析张量对象吗?

Posted

技术标签:

【中文标题】提供给 `tf.data.Dataset.from_generator(...)` 的 map 函数可以解析张量对象吗?【英文标题】:Can the map function supplied to `tf.data.Dataset.from_generator(...)` resolve a tensor object? 【发布时间】:2018-06-29 04:26:34 【问题描述】:

我想创建一个tf.data.Dataset.from_generator(...) 数据集。我需要传入一个 Python generator

我想将前一个数据集的属性传递给生成器,如下所示:

dataset = dataset.interleave(
  map_func=lambda x: tf.data.Dataset.from_generator(generator=lambda: gen(x), output_types=tf.int64),
  cycle_length=2
)

我在哪里定义 gen(...) 以获取一个值(这是指向某些数据的指针,例如 gen 知道如何访问的文件名)。

这失败了,因为gen 接收到张量对象,而不是 python/numpy 值。

有没有办法将张量对象解析为gen(...) 内的值?

交错生成器的原因是我可以使用其他数据集操作(例如 .shuffle().repeat())来操作数据指针/文件名列表,而无需将它们烘焙到 gen(...) 函数中,这将如果我直接从数据指针/文件名列表开始使用生成器,则有必要。

我想使用生成器,因为每个数据指针/文件名都会生成大量数据值。

【问题讨论】:

看来这里的答案是否定的,有 tf.py_func 为地图函数提供此功能,但 tf.py_func 不适用于生成器。如果有更多信息曝光,我会再开放一段时间。这似乎是对数据集管道流程的严格限制。 【参考方案1】:

TensorFlow 现在支持将张量参数传递给生成器:

def map_func(tensor):
    dataset = tf.data.Dataset.from_generator(generator, tf.float32, args=(tensor,))
    return dataset

【讨论】:

感谢您添加答案,我还没有注意到。这是个好消息!【参考方案2】:

答案确实是否定的。以下是对几个相关 git 问题(截至撰写本文时开放)的参考,以了解该问题的进一步发展:

https://github.com/tensorflow/tensorflow/issues/13101

https://github.com/tensorflow/tensorflow/issues/16343

【讨论】:

以上是关于提供给 `tf.data.Dataset.from_generator(...)` 的 map 函数可以解析张量对象吗?的主要内容,如果未能解决你的问题,请参考以下文章

如何使用 tf.data.Dataset.from_generator() 向生成器函数发送参数?

tf.data.Dataset.from_tensor_slices() 详解

tf.data.Dataset.from_tensor_slices,张量和渴望模式

tf.data.Dataset.from_tensor_slices中的shuffle()repeat()batch()用法

Tensorflow:连接多个tf.Dataset非常慢

有没有一种简单的方法可以在 tensorflow 中使用 tf.data.Dataset.from_generator 和自定义 model_fn(Estimator) 中的功能