如何在 Tensorflow 数据集中扩充数据?

Posted

技术标签:

【中文标题】如何在 Tensorflow 数据集中扩充数据?【英文标题】:How to augment data in a Tensorflow Dataset? 【发布时间】:2021-02-10 02:25:47 【问题描述】:

对于一组图像,我很困惑术语数据增强是否意味着转换当前数据集(例如裁剪/翻转/旋转/...),或者它是否意味着通过添加裁剪/来增加数据量翻转/旋转图像到初始数据集。据我了解,从这个question 和这个one,它意味着两者。如果我错了,请纠正我。

所以,使用 Tensorflow 数据集,我想实现第二个:增加数据量。

我正在使用来自 TFDS 的 ImageNet 数据(训练集不可用):

import tensorflow_datasets as tfds
ds = tfds.load('imagenet_a', split='test', as_supervised=True)

我想翻转图像:

def transform(image, label):
    image = tf.image.flip_left_right(image)
    return image, label

如果我将转换直接应用于数据集,效果会很好。但不会增加数据量:

ds = ds.map(transform)

所以,我尝试创建第二个数据集并将两者连接起来:

ds0 = ds.map(transform)
ds = ds.concatenate(ds0)

但我收到以下错误:

TypeError: Two datasets to concatenate have different types (tf.uint8, tf.int64) and (tf.float32, tf.int64)

这是连接两个数据集以增加训练集的方法吗? 或者如何正确地做到这一点? (或如何解决我的错误)

我知道ImageDataGenerator,但它不包含我想要的转换

【问题讨论】:

我不知道,但也许this 会有所帮助。 【参考方案1】:

正如错误清楚地表明,两个数据集应该具有相同的数据类型,您可以使用 tf.cast 实现这一点,但这对于大型数据集来说有点忙。

您还可以使用tf.data.experimental.sample_from_datasets 合并数据集

下面是带有插图的代码。

import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from keras.preprocessing.image import img_to_array, array_to_img
ds , info = tfds.load('imagenet_a', split='test', as_supervised=True,with_info=True)

原始示例图片:

vis = tfds.visualization.show_examples(ds, info)

我正在拍摄 10 张图像进行测试,并使用 map() 函数随机翻转这 10 张图像以创建新数据集。

ds1 = ds.take(10)
ds2 = ds1.map(lambda image, label: (tf.image.random_flip_left_right(image), label))
#Merging both the datasets

new_ds = tf.data.experimental.sample_from_datasets([ds1,ds2])
print(len(list(new_ds))) # Which returns 20, 10 original plus 10 randomly filpped images. 

f, axarr = plt.subplots(5,4,figsize=(15, 15))

ix = 0
i = 0
count = 0
k = 0

for images, labels in new_ds:
  crop_img = array_to_img(images)
  axarr[i,ix].imshow(crop_img)
  ix=ix+1
  count = count + 1
  if count == 4:
     i = i + 1
     count = 0
     ix = 0

合并数据集:

您可以看到原始图像和随机翻转图像的合并数据。

【讨论】:

以上是关于如何在 Tensorflow 数据集中扩充数据?的主要内容,如果未能解决你的问题,请参考以下文章

如何在 Tensorflow Estimator 的 input_fn 中执行数据扩充

如何从 TensorFlow 数据集中提取数据/标签

在 Tensorflow 中为输入和输出保持相同的数据集扩充

如何验证 TensorFlow 数据集中的图像? [复制]

如何在保持相同形状和尺寸的同时获得 tensorflow 数据集中的最大值?

将 pandas 数据帧中的 numpy 数组加载到 tensorflow 数据集中