如何在 tf 2.1.0 中创建 tf.data.Dataset 的训练、测试和验证拆分

Posted

技术标签:

【中文标题】如何在 tf 2.1.0 中创建 tf.data.Dataset 的训练、测试和验证拆分【英文标题】:how to create train, test & validation split of tf.data.Dataset in tf 2.1.0 【发布时间】:2020-06-27 11:15:16 【问题描述】:

以下代码复制自: https://www.tensorflow.org/tutorials/load_data/images

该代码旨在创建从网络下载的图像数据集,并根据它们的类存储到文件夹中,请参阅上面的链接了解整个上下文!

list_ds = tf.data.Dataset.list_files(str(data_dir/'*/*'))

for f in list_ds.take(5):
  print(f.numpy())

def get_label(file_path):
  # convert the path to a list of path components
  parts = tf.strings.split(file_path, os.path.sep)
  # The second to last is the class-directory
  return parts[-2] == CLASS_NAMES

def decode_img(img):
  # convert the compressed string to a 3D uint8 tensor
  img = tf.image.decode_jpeg(img, channels=3)
  # Use `convert_image_dtype` to convert to floats in the [0,1] range.
  img = tf.image.convert_image_dtype(img, tf.float32)
  # resize the image to the desired size.
  return tf.image.resize(img, [IMG_WIDTH, IMG_HEIGHT])

def process_path(file_path):
  label = get_label(file_path)
  # load the raw data from the file as a string
  img = tf.io.read_file(file_path)
  img = decode_img(img)
  return img, label

# Set `num_parallel_calls` so multiple images are loaded/processed in parallel.
labeled_ds = list_ds.map(process_path, num_parallel_calls=AUTOTUNE)

for image, label in labeled_ds.take(1):
  print("Image shape: ", image.numpy().shape)
  print("Label: ", label.numpy())

def prepare_for_training(ds, cache=True, shuffle_buffer_size=1000):
  # This is a small dataset, only load it once, and keep it in memory.
  # use `.cache(filename)` to cache preprocessing work for datasets that don't
  # fit in memory.
  if cache:
    if isinstance(cache, str):
      ds = ds.cache(cache)
    else:
      ds = ds.cache()

  ds = ds.shuffle(buffer_size=shuffle_buffer_size)

  # Repeat forever
  ds = ds.repeat()

  ds = ds.batch(BATCH_SIZE)

  # `prefetch` lets the dataset fetch batches in the background while the model
  # is training.
  ds = ds.prefetch(buffer_size=AUTOTUNE)

  return ds

train_ds = prepare_for_training(labeled_ds)

我们终于得到了train_ds,它是一个 PreffetchDataset 对象,包含整个图像数据集,标签! 如何将train_ds 拆分为训练、测试和验证集以将其输入模型?

【问题讨论】:

【参考方案1】:

ds.repeat() 调用后,数据集是无限的,拆分无限数据集的效果不是很好。因此,您应该在prepare_training() 调用之前拆分它。像这样:

labeled_ds = list_ds.map(process_path, num_parallel_calls=AUTOTUNE)
labeled_ds = labeled_ds.shuffle(10000).batch(BATCH_SIZE)

# Size of dataset
n = sum(1 for _ in labeled_ds)
n_train = int(n * 0.8)
n_valid = int(n * 0.1)
n_test = n - n_train - n_valid

train_ds = labeled_ds.take(n_train)
valid_ds = labeled_ds.skip(n_train).take(n_valid)
test_ds = labeled_ds.skip(n_train + n_valid).take(n_test)

n = sum(1 for _ in labeled_ds) 行遍历数据集一次以获取其大小,然后将其 3 路拆分为 80%/10%/10%。

【讨论】:

以上是关于如何在 tf 2.1.0 中创建 tf.data.Dataset 的训练、测试和验证拆分的主要内容,如果未能解决你的问题,请参考以下文章

如何在 tf.data.Dataset 中输入不同大小的列表列表

如何将 tf.data.Dataset 与 kedro 一起使用?

如何使用提供的需要 tf.Tensor 的 preprocess_input 函数预处理 tf.data.Dataset?

如何在 keras 自定义回调中访问 tf.data.Dataset?

如何使用 tf.data 创建多元时间序列数据集?

如何在 tensorboard 中显示 Tensorflow 2.0 中的 tf.data.Dataset.map 子图?