如何在 tfds.load() 之后在 TensorFlow 2.0 中应用数据增强

Posted

技术标签:

【中文标题】如何在 tfds.load() 之后在 TensorFlow 2.0 中应用数据增强【英文标题】:How to apply data augmentation in TensorFlow 2.0 after tfds.load() 【发布时间】:2019-08-04 01:58:15 【问题描述】:

我正在关注this guide。

它展示了如何使用tfds.load() 方法从新的 TensorFlow 数据集中下载数据集:

import tensorflow_datasets as tfds    
SPLIT_WEIGHTS = (8, 1, 1)
splits = tfds.Split.TRAIN.subsplit(weighted=SPLIT_WEIGHTS)

(raw_train, raw_validation, raw_test), metadata = tfds.load(
    'cats_vs_dogs', split=list(splits),
    with_info=True, as_supervised=True)

接下来的步骤展示了如何使用 map 方法将函数应用于数据集中的每个项目:

def format_example(image, label):
    image = tf.cast(image, tf.float32)
    image = image / 255.0
    # Resize the image if required
    image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
    return image, label

train = raw_train.map(format_example)
validation = raw_validation.map(format_example)
test = raw_test.map(format_example)

然后访问我们可以使用的元素:

for features in ds_train.take(1):
  image, label = features["image"], features["label"]

for example in tfds.as_numpy(train_ds):
  numpy_images, numpy_labels = example["image"], example["label"]

但是,该指南没有提及任何有关数据增强的内容。我想使用类似于 Keras 的 ImageDataGenerator 类的实时数据增强。我尝试使用:

if np.random.rand() > 0.5:
    image = tf.image.flip_left_right(image)

以及format_example() 中的其他类似增强功能,但是,我如何验证它正在执行实时增强而不是替换数据集中的原始图像?

我可以通过将batch_size=-1 传递给tfds.load() 将完整的数据集转换为Numpy 数组,然后使用tfds.as_numpy() 但是,这会将所有不需要的图像加载到内存中。我应该能够使用train = train.prefetch(tf.data.experimental.AUTOTUNE) 为下一个训练循环加载足够的数据。

【问题讨论】:

您可能还想查看this answer,它会在增强后呈现数据,因此您可以更加确定它正在工作(而且这个例子更令人信服)。 【参考方案1】:

您从错误的方向解决问题。

首先,例如使用tfds.loadcifar10 下载数据(为简单起见,我们将使用默认的TRAINTEST 拆分):

import tensorflow_datasets as tfds

dataloader = tfds.load("cifar10", as_supervised=True)
train, test = dataloader["train"], dataloader["test"]

(您可以使用自定义tfds.Split 对象来创建验证数据集或其他see documentation)

traintesttf.data.Dataset 对象,因此您可以使用 mapapplybatch 和类似的函数。

下面是一个例子,我会(主要使用tf.image):

将每个图像转换为0-1范围内的tf.float64(不要使用官方文档中的这个愚蠢的sn-p,这样可以确保正确的图像格式) cache() 结果可以在每个 repeat 之后重复使用 随机翻转left_to_right每张图片 随机改变图像对比度 随机数据和批处理 重要提示:在数据集用完时重复所有步骤。这意味着在一个 epoch 之后,所有上述转换都会再次应用(缓存的转换除外)。

这是执行上述操作的代码(您可以将lambdas 更改为仿函数或函数):

train = train.map(
    lambda image, label: (tf.image.convert_image_dtype(image, tf.float32), label)
).cache().map(
    lambda image, label: (tf.image.random_flip_left_right(image), label)
).map(
    lambda image, label: (tf.image.random_contrast(image, lower=0.0, upper=1.0), label)
).shuffle(
    100
).batch(
    64
).repeat()

这样的tf.data.Dataset可以直接传递给Keras的fitevaluatepredict方法。

验证它实际上是这样工作的

我看你对我的解释很怀疑,我们来看一个例子:

1。获取一小部分数据

这是获取单个元素的一种方法,公认不可读且不直观,但如果您使用 Tensorflow 做任何事情,您应该可以接受:

# Horrible API is horrible
element = tfds.load(
    # Take one percent of test and take 1 element from it
    "cifar10",
    as_supervised=True,
    split=tfds.Split.TEST.subsplit(tfds.percent[:1]),
).take(1)

2。重复数据,检查是否相同:

使用Tensorflow 2.0 实际上可以做到这一点而无需愚蠢的解决方法(几乎):

element = element.repeat(2)
# You can iterate through tf.data.Dataset now, finally...
images = [image[0] for image in element]
print(f"Are the same: tf.reduce_all(tf.equal(images[0], images[1]))")

不出所料地返回:

Are the same: True

3。通过随机增强检查每次重复后数据是否不同

下面sn -p repeats 5次单个元素,检查哪些相等哪些不同。

element = (
    tfds.load(
        # Take one percent of test and take 1 element
        "cifar10",
        as_supervised=True,
        split=tfds.Split.TEST.subsplit(tfds.percent[:1]),
    )
    .take(1)
    .map(lambda image, label: (tf.image.random_flip_left_right(image), label))
    .repeat(5)
)

images = [image[0] for image in element]

for i in range(len(images)):
    for j in range(i, len(images)):
        print(
            f"i same as j: tf.reduce_all(tf.equal(images[i], images[j]))"
        )

输出(在我的情况下,每次运行都会不同):

0 same as 0: True
0 same as 1: False
0 same as 2: True
0 same as 3: False
0 same as 4: False
1 same as 1: True
1 same as 2: False
1 same as 3: True
1 same as 4: True
2 same as 2: True
2 same as 3: False
2 same as 4: False
3 same as 3: True
3 same as 4: True
4 same as 4: True

您也可以将这些图像中的每一个都投射到numpy,并使用skimage.io.imshowmatplotlib.pyplot.imshow 或其他替代方法亲自查看这些图像。

实时数据增强可视化的另一个例子

This answer 使用TensorboardMNIST 提供了一个更全面、更易读的数据增强视图,可能想检查一下(是的,无耻的插件,但我猜很有用)。

【讨论】:

来自映射函数 here 的文档:此转换将 map_func 应用于此数据集的每个元素,并返回一个包含转换后元素的新数据集,其顺序与它们出现的顺序相同输入。 确实如此。检查我刚刚添加的重要提示: 部分。基本上,每个增强都应用于数据的每个部分(在这种情况下,单个元素,如果在它之前使用batch(),则可以是批处理,这样应该更快)并且它在有或没有增强的情况下返回(如果是随机的) )。当tf.data.Dataset 用尽并使用repeat(为了训练多个时期/无限期)时,所有操作都会重复,除了我们在第一次通过时缓存的那些。它消除了混乱吗? 好的,我使用repeat时如何验证所有操作是否重复? 我看你对tensorflow不太信任,我能理解。我添加了一个比较random_flip_left_right 之前和之后的图像的示例。如果您愿意,您可以通过这种方式进行自己的更广泛的测试。 感谢您的示例!验证步骤之后,事情就清楚多了。

以上是关于如何在 tfds.load() 之后在 TensorFlow 2.0 中应用数据增强的主要内容,如果未能解决你的问题,请参考以下文章

如何使用张量流数据集 (TDFS) 作为张量流模型的输入?

将自定义数据加载到张量流管道中

Tensorflow2数据增强(data_augmentation)代码

TensorFlow 的 ./configure 在哪里以及如何启用 GPU 支持?

Tensorflow:您如何在模型训练期间实时监控 GPU 性能?

TensorFlow:在PyCharm中配置TensorFlow