如何使用 Tensorflow 2.0 数据集在训练时执行 10 次裁剪图像增强

Posted

技术标签:

【中文标题】如何使用 Tensorflow 2.0 数据集在训练时执行 10 次裁剪图像增强【英文标题】:How to perform 10 Crop Image Augmentation at training time using Tensorflow 2.0 Dataset 【发布时间】:2020-01-03 09:59:33 【问题描述】:

我正在使用 Tensorflow Dataset API 并从 TFRecord 文件中读取数据。我可以使用 map 函数并使用 random_flip_left_right、random_crop 等方法进行数据增强。

但是,当我尝试复制 AlexNet 论文时,我遇到了一个问题。我需要翻转每张图像,然后进行 5 次裁剪(左、上、下、右和中)。

所以输入数据集的大小会增加 10 倍。无论如何使用tensorflow数据集API来做到这一点? map() 函数只返回一张图像,我无法增加图像的数量。

请看我现在的代码。

dataset = dataset.map(parse_image, num_parallel_calls=tf.data.experimental.AUTOTUNE) \
    .map(lambda image, label: (tf.image.random_flip_left_right(image), label), num_parallel_calls=tf.data.experimental.AUTOTUNE) \
    .map(lambda image, label: (tf.image.random_crop(image, size=[227, 227, 3]), label), num_parallel_calls=tf.data.experimental.AUTOTUNE) \
    .shuffle(buffer_size=1000) \
    .repeat() \
    .batch(256) \
    .prefetch(tf.data.experimental.AUTOTUNE)

【问题讨论】:

【参考方案1】:
def tile_crop(img, label):
    img_shape = tf.shape(img)
    crop_left = lambda img: tf.image.random_crop(img[:,:img_shape[1]//2,:], size=[227,227,3])
    crop_top = lambda img: tf.image.random_crop(img[:img_shape[0]//2,:,:], size=[227,227,3])
    ...
    img = tf.image.random_flip_left_right(img)
    img = tf.stack([crop_left(img), crop_top(img),...], axis=0])
    label = tf.reshape(label, [1,1]) #size: (,) -> (1,1)
    label = tf.tile(label, [5, 1]) #size: (1,1) -> (5,1)
    return img, label
dt = parsed_dataset.map(tile_crop) #size: ((5,height,width,channels), (5, 1))
dt = dt.unbatch() #size: ((height,width,channels), (1))

然后您可以随意使用随机播放/重复/批处理/预取。确保每张裁剪的图像都具有相同的尺寸。

【讨论】:

以上是关于如何使用 Tensorflow 2.0 数据集在训练时执行 10 次裁剪图像增强的主要内容,如果未能解决你的问题,请参考以下文章

用 Python 生成的 Tensorflow 数据集在 Tensorflow Java API(标签图像)中有不同的读数

如何在 tfds.load() 之后在 TensorFlow 2.0 中应用数据增强

如何在 TensorFlow 2.0 中使用 Dataset.window() 方法创建的窗口?

如何更改 gpt-2 代码以使用 Tensorflow 2.0?

如何保存使用Tensorflow 1.xx中的.meta检查点模型作为部分的Tensorflow 2.0模型?

如何修改 Tensorflow 2.0 中的 epoch 数?