TensorFlow 数据集拆分不起作用

Posted

技术标签:

【中文标题】TensorFlow 数据集拆分不起作用【英文标题】:Tensorflow dataset splitting does not work 【发布时间】:2022-01-01 16:25:28 【问题描述】:

我最近尝试使用tf.data API。我创建了一个图像数据集,并且必须拆分为 train/val/test。我使用ds.takeds.skip 使用以下方法,但始终正确获取train_ds,并且在test_ds 和val_ds 中没有数据。

DATASET_SIZE = 2000
train_size = int(0.7 * DATASET_SIZE)  # 1400
val_size   = int(0.15 * DATASET_SIZE) # 300
test_size  = int(0.15 * DATASET_SIZE) # 300
train_ds   = ds.take(train_size)
val_ds     = ds.skip(train_size).take(val_size)
test_ds    = ds.skip(train_size+val_size).take(test_size)

当我运行以下内容时:

for image, label in train_ds.take(1): 
  print("Image shape: ", image.shape)
  print("Label: ", label.numpy())

我看到输出为:

Image shape:  (32, 400, 400, 3)
Label:  [39 23 21 27 28 18 28 30 28 44 34 37 21 39 35 26 48 37 41 30 22 36 46 28
 34 38 33 32 36 35 25 24]

但是如果我尝试在上面使用test_ds.take(1)val_ds.take(1),则没有输出。 test_dsval_ds 似乎是空数据集。另外,当我稍后在我的model.fit() 函数中使用val_ds 时,我看不到val_loss

我可以使用其他对我有用的技术,但想了解原因/我在这里做错了什么?

【问题讨论】:

您能否提供有关您如何构建 ds 的更多信息? 最初我使用ds = tf.data.Dataset.from_tensor_slices((filepaths, labels)) 创建了包含文件路径(即str)和标签(即int64)的ds。然后使用函数parse_function(filepath, label) 读取所有图像。但是当我使用train_ds = ds.take(1400) val_ds = ds.take(300) test_ds = ds.take(300) 拆分时很奇怪,它给了我所需的样本(但在所有拆分中给出了我不想要的相同样本)。 但我想说我如何创建数据集并不重要。不管怎样,请参阅我有一个数据集ds,我想拆分它。第一个函数train_ds = ds.take(train_size) 给了我一个完美的train_ds 样本train_size。但是接下来的两个拆分似乎无法获得val_dstest_dsds.skip 函数有问题还是我做错了? 回答有用吗? 【参考方案1】:

我猜是parse_function(filepath, label) 导致了这个问题,因为这个简单的例子工作得很好:

import tensorflow as tf
import numpy as np

DATASET_SIZE = 2000
data = np.random.random((DATASET_SIZE, 5))
labels = np.random.random((DATASET_SIZE, 1))

ds = tf.data.Dataset.from_tensor_slices((data, labels))
train_size = int(0.7 * DATASET_SIZE)  # 1400
val_size   = int(0.15 * DATASET_SIZE) # 300
test_size  = int(0.15 * DATASET_SIZE) # 300
train_ds   = ds.take(train_size)
val_ds     = ds.skip(train_size).take(val_size)
test_ds    = ds.skip(train_size+val_size).take(test_size)

print(len(train_ds), len(val_ds), len(test_ds))
for x, y  in val_ds.take(10):
  print(x.shape, y.shape)
1400 300 300
(5,) (1,)
(5,) (1,)
(5,) (1,)
(5,) (1,)
(5,) (1,)
(5,) (1,)
(5,) (1,)
(5,) (1,)
(5,) (1,)
(5,) (1,)

可能在parse_function 中找不到或无法读取某些文件路径导致样本较少,但如果没有看到此功能则很难说。

【讨论】:

以上是关于TensorFlow 数据集拆分不起作用的主要内容,如果未能解决你的问题,请参考以下文章

univariate_data 函数在 python tensorflow 教程(熊猫数据框)中不起作用

TensorFlow 不起作用

加密大于 RSA 密钥(模数)大小的数据。将拆分的数据加密成块,似乎不起作用。有任何想法吗?

为啥 random.seed() 在生成数据集时不起作用?

TENSORFLOW.JS 3D 姿势估计不起作用

为啥 webpack 代码拆分对我不起作用?