tensorflow dataset, tfds的数据加载中的数据增广,导致TypError或AttributeError

Posted

技术标签:

【中文标题】tensorflow dataset, tfds的数据加载中的数据增广,导致TypError或AttributeError【英文标题】:Data augmentation in the data loading of tensorflow dataset, tfds, resulting in TypError or AttributeError 【发布时间】:2021-12-14 14:30:54 【问题描述】:

我正在尝试做一些数据增强,但我对张量不太熟悉。 这是我开始的代码:

def _random_apply(func, x, p):
  return tf.cond(tf.less(tf.random.uniform([], minval=0, maxval=1, dtype=tf.float32),
                  tf.cast(p, tf.float32)),
        lambda: func(x),
        lambda: x)
def _resize_with_pad(image):
  image = tf.image.resize_with_pad(image, target_height=IMG_S, target_width=IMG_S)   
  return image

def augment(image, label):
  img = _random_apply(tf.image.flip_left_right(image), image, p=0.2)
  img = _random_apply(_resize_with_pad(img), img, p=1)
  return img, label
train_dataset = (
    train_ds
    .shuffle(1000)
    .map(augment, num_parallel_calls=tf.data.AUTOTUNE)
    .prefetch(tf.data.AUTOTUNE)
)

导致以下错误。

----> 4     .map(augment, num_parallel_calls=tf.data.AUTOTUNE)

TypeError: 'Tensor' object is not callable

然后我想如果我将它转换为 numpy 可能会起作用。

def _random_apply(func, x, p):
  return tf.cond(tf.less(tf.random.uniform([], minval=0, maxval=1, dtype=tf.float32),
                  tf.cast(p, tf.float32)),
        lambda: func(x),
        lambda: x)
def _resize_with_pad(image):
  image = image.numpy()
  image = tf.image.resize_with_pad(image, target_height=IMG_S, target_width=IMG_S).numpy()  
  return image

def augment(image, label):
  image = image.numpy()
  img = _random_apply(tf.image.flip_left_right(image).numpy(), image, p=0.2)
  img = _random_apply(_resize_with_pad(img), img, p=1)
  return img, label
train_dataset = (
    train_ds
    .shuffle(1000)
    .map(augment, num_parallel_calls=tf.data.AUTOTUNE)
    .prefetch(tf.data.AUTOTUNE)
)

但现在我得到了这个错误。

----> 4     .map(augment, num_parallel_calls=tf.data.AUTOTUNE)

 AttributeError: 'Tensor' object has no attribute 'numpy'

我尝试在answer 中执行类似操作,现在我直接没有收到错误,而是在下一个代码块中:

for image, _ in train_dataset.take(9):
etc
InvalidArgumentError 
----> 1 for image, _ in train_dataset.take(9):

InvalidArgumentError: TypeError: 'tensorflow.python.framework.ops.EagerTensor' object is not callable

有人知道我做错了什么吗?

【问题讨论】:

【参考方案1】:

augment 中,您将张量传递给_random_applytf.image.flip_left_right(image) 返回一个张量。然后,在_random_apply 中,您将使用该张量,就像它是一个函数一样。您需要将tf.flip_left_right 作为可调用对象传递:

def augment(image):
    img = _random_apply(tf.image.flip_left_right, image, p=0.2)
    img = _random_apply(_resize_with_pad, img, p=1)
    return img

完整的工作示例:

import tensorflow as tf

train_ds = tf.data.Dataset.from_tensor_slices(tf.random.uniform((100, 224, 224, 3)))


def _random_apply(func, x, p):
    return tf.cond(tf.less(tf.random.uniform([], minval=0, maxval=1, dtype=tf.float32),
                           tf.cast(p, tf.float32)),
                   lambda: func(x),
                   lambda: x)


def _resize_with_pad(image):
    image = tf.image.resize_with_pad(image, target_height=200, target_width=200)
    return image


def augment(image):
    img = _random_apply(tf.image.flip_left_right, image, p=0.2)
    img = _random_apply(_resize_with_pad, img, p=1)
    return img


train_dataset = train_ds.map(augment)

batch = next(iter(train_dataset))

【讨论】:

以上是关于tensorflow dataset, tfds的数据加载中的数据增广,导致TypError或AttributeError的主要内容,如果未能解决你的问题,请参考以下文章

如何使用 tensorflow 数据集访问图像

如何在 tfds.load() 之后在 TensorFlow 2.0 中应用数据增强

将自定义数据加载到张量流管道中

Tensorflow2数据增强(data_augmentation)代码

tensorflow Dataset及TFRecord一些要点持续更新

Tensorflow:连接多个tf.Dataset非常慢