如何在 TensorFlow 中使用我自己的数据将图像拆分为测试和训练集

Posted

技术标签:

【中文标题】如何在 TensorFlow 中使用我自己的数据将图像拆分为测试和训练集【英文标题】:How to split images into test and train set using my own data in TensorFlow 【发布时间】:2020-05-24 14:54:50 【问题描述】:

我在这里有点困惑...我刚刚花了一个小时阅读有关如何在 TensorFlow 中将我的数据集拆分为测试/训练的信息。我正在按照本教程导入我的图像:https://www.tensorflow.org/tutorials/load_data/images。显然,可以使用 sklearn:model_selection.train_test_split 拆分为训练/测试。

但我的问题是:我何时将数据集拆分为训练/测试。我已经用我的数据集完成了这个(见下文),现在怎么办?我该如何拆分它?在将文件加载为tf.data.Dataset 之前我必须这样做吗?

# determine names of classes
CLASS_NAMES = np.array([item.name for item in data_dir.glob('*') if item.name != "LICENSE.txt"])
print(CLASS_NAMES)

# count images
image_count = len(list(data_dir.glob('*/*.png')))
print(image_count)


# load the files as a tf.data.Dataset
list_ds = tf.data.Dataset.list_files(str(cwd + '/train/' + '*/*'))

另外,我的数据结构如下所示。没有 test 文件夹,没有 val 文件夹。我需要从那组火车中抽取 20% 的时间进行测试。

train
 |__ class 1
 |__ class 2
 |__ class 3

【问题讨论】:

【参考方案1】:

你可以使用tf.keras.preprocessing.image.ImageDataGenerator:

image_generator = tf.keras.preprocessing.image.ImageDataGenerator(validation_split=0.2)
train_data_gen = image_generator.flow_from_directory(directory='train',
                                                     subset='training')
val_data_gen = image_generator.flow_from_directory(directory='train',
                                                   subset='validation')

请注意,您可能需要为您的生成器设置其他 data-related parameters。

更新:您可以通过skip()take() 获取数据集的两个切片:

val_data = data.take(val_data_size)
train_data = data.skip(val_data_size)

【讨论】:

知道了!谢谢。但是,如果我使用tf.data 加载我的图像,然后使用Dataset.map 创建图像数据集,标签对呢?我现在有我所有的图像在train_ds = prepare_for_training(labeled_ds) 那你会如何分割它?我正在关注本教程:tensorflow.org/tutorials/load_data/images【参考方案2】:

如果您将所有数据都放在同一个文件夹中,并希望使用tf.data 拆分为验证/测试,请执行以下操作:

list_ds = tf.data.Dataset.list_files(str(cwd + '/train/' + '*/*'))
image_count = len(list(data_dir.glob('*/*.png')))

val_size = int(image_count * 0.2) 
train_set = list_ds.skip(val_size)
val_set = list_ds.take(val_size) 

【讨论】:

以上是关于如何在 TensorFlow 中使用我自己的数据将图像拆分为测试和训练集的主要内容,如果未能解决你的问题,请参考以下文章

在自己的数据集上训练 TensorFlow 对象检测

如何在 iOS 中使用经过 Tensorflow 训练的机器学习模型

在 TensorFlow 图像分类中获取标签

tensorflow faster rann

将本地训练的 TensorFlow 模型导入 Google Colab

如何在 TensorFlow 中计算 CNN 的准确度