在 tf.data 中切片导致“在图形执行中不允许迭代 `tf.Tensor`”错误

Posted

技术标签:

【中文标题】在 tf.data 中切片导致“在图形执行中不允许迭代 `tf.Tensor`”错误【英文标题】:Slicing in tf.data causes "iterating over `tf.Tensor` is not allowed in Graph execution" error 【发布时间】:2021-07-03 23:10:20 【问题描述】:

我有一个如下创建的数据集,其中image_train_path 是图像文件路径的列表, 例如。 [b'/content/drive/My Drive/data/folder1/im1.png', b'/content/drive/My Drive/data/folder2/im6.png',...]。我需要提取文件夹路径,例如'/content/drive/My Drive/data/folder1',然后进行一些其他操作。我尝试使用preprocessData 函数来执行此操作,如下所示。

dataset = tf.data.Dataset.from_tensor_slices(image_train_path)
dataset = dataset.map(preprocessData, num_parallel_calls=16)

preprocessData 在哪里:

def preprocessData(images_path):
    folder=tf.strings.split(images_path,'/')
    foldername=tf.strings.join(tf.slice(folder,(0,),(6,)),'/')
    ....

但是,切片线会导致以下错误:

OperatorNotAllowedInGraphError: in user code:

    <ipython-input-21-2a9827982c16>:4 preprocessData  *
        foldername=tf.strings.join(tf.slice(folder,(0,),(6,)),'/')
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/util/dispatch.py:210 wrapper  **
        result = dispatch(wrapper, args, kwargs)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/util/dispatch.py:122 dispatch
        result = dispatcher.handle(args, kwargs)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/ops/ragged/ragged_dispatch.py:130 handle
        for elt in x:
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/ops.py:524 __iter__
        self._disallow_iteration()
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/ops.py:520 _disallow_iteration
        self._disallow_in_graph_mode("iterating over `tf.Tensor`")
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/ops.py:500 _disallow_in_graph_mode
        " this function with @tf.function.".format(task))

    OperatorNotAllowedInGraphError: iterating over `tf.Tensor` is not allowed in Graph execution. Use Eager execution or decorate this function with @tf.function.

我在 Tf2.4 和 tf nightly 中都试过这个。我尝试使用@tf.function 以及tf.data.experimental.enable_debug_mode() 进行装饰。总是给出同样的错误。

我不太明白是哪个部分导致了“迭代”,尽管我猜问题是切片。有没有其他方法可以做到这一点?

【问题讨论】:

能把preprocessData的完整代码贴出来吗? 【参考方案1】:

函数tf.strings.join 需要一个张量列表,如文档所述:

参数

输入:具有相同大小和 tf.string dtype 的 tf.Tensor 对象列表

tf.slice 返回一个 Tensor,然后 join 函数会尝试对其进行迭代,从而导致错误。

您可以使用简单的列表理解来提供函数:

def preprocessData(images_path):
    folder=tf.strings.split(images_path,'/')
    foldername=tf.strings.join([folder[i] for i in range(6)],"/")
    return foldername

【讨论】:

以上是关于在 tf.data 中切片导致“在图形执行中不允许迭代 `tf.Tensor`”错误的主要内容,如果未能解决你的问题,请参考以下文章

如何将 sample_weights 与 3D 医疗数据一起使用,而没有 model.fit(x=tf.data.Dataset) 导致无法挤压最后一个暗淡等错误

TensorFlow.org教程笔记 DataSets 快速入门

建议在 tensorflow 2.0 中调试 `tf.data.Dataset` 操作

如何在 tf.data.Dataset.map 中使用 sklearn.preprocessing?

如何在 tf.data.Dataset 中输入不同大小的列表列表

如何在 keras 自定义回调中访问 tf.data.Dataset?