如何将 tf.data.Dataset 与 kedro 一起使用?

Posted

技术标签:

【中文标题】如何将 tf.data.Dataset 与 kedro 一起使用?【英文标题】:How to use tf.data.Dataset with kedro? 【发布时间】:2020-12-23 01:42:24 【问题描述】:

我正在使用tf.data.Dataset 准备用于训练 tf.kears 模型的流数据集。使用kedro,有没有办法创建一个节点并返回创建的tf.data.Dataset,以便在下一个训练节点中使用它?

MemoryDataset 可能不起作用,因为tf.data.Dataset 不能被腌制(deepcopy 是不可能的),另请参阅this SO question。根据issue #91,MemoryDataset 中的深层复制是为了避免其他节点修改数据。有人可以详细说明为什么/如何发生这种并发修改吗?

从docs,似乎有一个copy_mode = "assign"。如果数据不可提取,是否可以使用此选项?

另一个解决方案(在 issue 91 中也提到过)是只使用一个函数在训练节点内生成流tf.data.Dataset,而不需要前面的数据集生成节点。但是,我不确定这种方法的缺点是什么(如果有的话)。如果有人能举一些例子就更好了。

另外,我想避免存储流数据集的完整输出,例如使用 tfrecordstf.data.experimental.save,因为这些选项会占用大量磁盘存储空间。

有没有办法只传递创建的tf.data.Dataset 对象以将其用于训练节点?

【问题讨论】:

【参考方案1】:

在此提供解决方法以造福于社区,尽管它由@DataEngineerOne 在kedro.community 中提出。

根据@DataEngineerOne。

有了kedro,有没有办法创建节点并返回创建的 tf.data.Dataset 在下一个训练节点使用它?

是的,绝对!

有人可以详细说明一下为什么/如何并发 可以修改吗?

从文档中,似乎有一个 copy_mode = "assign" 。可不可能是 如果数据不可提取,是否可以使用此选项?

我还没有尝试过这个选项,但理论上它应该可以工作。您需要做的就是在包含copy_mode 选项的catalog.yml 文件中创建一个新的数据集条目。

例如:

# catalog.yml
tf_data:
  type: MemoryDataSet
  copy_mode: assign

# pipeline.py
node(
  tf_generator,
  inputs=...,
  outputs="tf_data",
)

我无法保证此解决方案,但请尝试一下,让我知道它是否适合您。

另一个解决方案(在 issue 91 中也提到过)是只使用 在训练中生成流式传输 tf.data.Dataset 的函数 节点,没有前面的数据集生成节点。但是,我 我不确定这种方法的缺点是什么(如果有的话)。 如果有人能举一些例子就更好了。

这也是一个很好的替代解决方案,我认为(猜测)MemoryDataSet 在这种情况下会自动使用assign,而不是正常的deepcopy,所以你应该没问题。

# node.py

def generate_tf_data(...):
  tensor_slices = [1, 2, 3]
  def _tf_data():
    dataset = tf.data.Dataset.from_tensor_slices(tensor_slices)
    return dataset
  return _tf_data

def use_tf_data(tf_data_func):
  dataset = tf_data_func()

# pipeline.py
Pipeline([
node(
  generate_tf_data,
  inputs=...,
  outputs='tf_data_func',
),
node(
  use_tf_data,
  inputs='tf_data_func',
  outputs=...
),
])

这里唯一的缺点是额外的复杂性。更多详情可以参考here。

【讨论】:

以上是关于如何将 tf.data.Dataset 与 kedro 一起使用?的主要内容,如果未能解决你的问题,请参考以下文章

如何在 tf.data.Dataset 对象上使用序列/生成器将部分数据放入内存?

tf.data.Dataset.interleave() 与 map() 和 flat_map() 究竟有何不同?

Tensorflow:如何查找 tf.data.Dataset API 对象的大小

如何使用提供的需要 tf.Tensor 的 preprocess_input 函数预处理 tf.data.Dataset?

如何在 tf.data.Dataset.map 中使用 sklearn.preprocessing?

如何在 tf.data.Dataset 中输入不同大小的列表列表