TensorFlow 数据集拆分不起作用
Posted
技术标签:
【中文标题】TensorFlow 数据集拆分不起作用【英文标题】:Tensorflow dataset splitting does not work 【发布时间】:2022-01-01 16:25:28 【问题描述】:我最近尝试使用tf.data
API。我创建了一个图像数据集,并且必须拆分为 train/val/test。我使用ds.take
和ds.skip
使用以下方法,但始终正确获取train_ds,并且在test_ds 和val_ds 中没有数据。
DATASET_SIZE = 2000
train_size = int(0.7 * DATASET_SIZE) # 1400
val_size = int(0.15 * DATASET_SIZE) # 300
test_size = int(0.15 * DATASET_SIZE) # 300
train_ds = ds.take(train_size)
val_ds = ds.skip(train_size).take(val_size)
test_ds = ds.skip(train_size+val_size).take(test_size)
当我运行以下内容时:
for image, label in train_ds.take(1):
print("Image shape: ", image.shape)
print("Label: ", label.numpy())
我看到输出为:
Image shape: (32, 400, 400, 3)
Label: [39 23 21 27 28 18 28 30 28 44 34 37 21 39 35 26 48 37 41 30 22 36 46 28
34 38 33 32 36 35 25 24]
但是如果我尝试在上面使用test_ds.take(1)
或val_ds.take(1)
,则没有输出。 test_ds
和 val_ds
似乎是空数据集。另外,当我稍后在我的model.fit()
函数中使用val_ds
时,我看不到val_loss
。
我可以使用其他对我有用的技术,但想了解原因/我在这里做错了什么?
【问题讨论】:
您能否提供有关您如何构建 ds 的更多信息? 最初我使用ds = tf.data.Dataset.from_tensor_slices((filepaths, labels))
创建了包含文件路径(即str)和标签(即int64)的ds。然后使用函数parse_function(filepath, label)
读取所有图像。但是当我使用train_ds = ds.take(1400) val_ds = ds.take(300) test_ds = ds.take(300)
拆分时很奇怪,它给了我所需的样本(但在所有拆分中给出了我不想要的相同样本)。
但我想说我如何创建数据集并不重要。不管怎样,请参阅我有一个数据集ds
,我想拆分它。第一个函数train_ds = ds.take(train_size)
给了我一个完美的train_ds
样本train_size
。但是接下来的两个拆分似乎无法获得val_ds
和test_ds
。 ds.skip
函数有问题还是我做错了?
回答有用吗?
【参考方案1】:
我猜是parse_function(filepath, label)
导致了这个问题,因为这个简单的例子工作得很好:
import tensorflow as tf
import numpy as np
DATASET_SIZE = 2000
data = np.random.random((DATASET_SIZE, 5))
labels = np.random.random((DATASET_SIZE, 1))
ds = tf.data.Dataset.from_tensor_slices((data, labels))
train_size = int(0.7 * DATASET_SIZE) # 1400
val_size = int(0.15 * DATASET_SIZE) # 300
test_size = int(0.15 * DATASET_SIZE) # 300
train_ds = ds.take(train_size)
val_ds = ds.skip(train_size).take(val_size)
test_ds = ds.skip(train_size+val_size).take(test_size)
print(len(train_ds), len(val_ds), len(test_ds))
for x, y in val_ds.take(10):
print(x.shape, y.shape)
1400 300 300
(5,) (1,)
(5,) (1,)
(5,) (1,)
(5,) (1,)
(5,) (1,)
(5,) (1,)
(5,) (1,)
(5,) (1,)
(5,) (1,)
(5,) (1,)
可能在parse_function
中找不到或无法读取某些文件路径导致样本较少,但如果没有看到此功能则很难说。
【讨论】:
以上是关于TensorFlow 数据集拆分不起作用的主要内容,如果未能解决你的问题,请参考以下文章
univariate_data 函数在 python tensorflow 教程(熊猫数据框)中不起作用