如何在 tf.data.Dataset 对象上使用序列/生成器将部分数据放入内存?
Posted
技术标签:
【中文标题】如何在 tf.data.Dataset 对象上使用序列/生成器将部分数据放入内存?【英文标题】:How to use sequence/generator on tf.data.Dataset object to fit partial data into memory? 【发布时间】:2020-11-18 06:14:00 【问题描述】:我正在 Google Colab 上使用 Keras 进行图像分类。我使用 tf.keras.preprocessing.image_dataset_from_directory() 函数 (https://www.tensorflow.org/api_docs/python/tf/keras/preprocessing/image_dataset_from_directory) 加载图像,该函数返回一个 tf.data.Dataset 对象:
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
data_dir,
validation_split=0.2,
subset="training",
seed=1234,
image_size=(img_height, img_width),
batch_size=batch_size,
label_mode="categorical")
我发现当数据包含数千张图像时,model.fit() 将在训练多个批次后使用所有内存(我使用的是 Google Colab,并且可以看到 RAM 使用量在第一个 epoch 期间增长)。 然后我尝试使用 Keras Sequence,这是将部分数据加载到 RAM 中的建议解决方案 (https://www.tensorflow.org/api_docs/python/tf/keras/utils/Sequence):
class DatasetGenerator(tf.keras.utils.Sequence):
def __init__(self, dataset):
self.dataset = dataset
def __len__(self):
return tf.data.experimental.cardinality(self.dataset).numpy()
def __getitem__(self, idx):
return list(self.dataset.as_numpy_iterator())[idx]
我用以下方法训练模型:
history = model.fit(DatasetGenerator(train_ds), ...)
问题是 getitem() 必须返回一批带索引的数据。但是,我使用的 list() 函数必须将整个数据集放入 RAM,因此在 DatasetGenerator 对象实例化时会达到内存限制(tf.data.Dataset 对象不支持使用 [] 进行索引)。
我的问题:
-
有没有办法实现 getitem()(从数据集对象中获取特定批次)而不将整个对象放入内存?
如果第 1 项不可行,是否有任何解决方法?
提前致谢!
【问题讨论】:
【参考方案1】:我了解到您担心将完整的数据集保存在内存中。
不用担心,tf.data.Dataset
API 非常高效,它不会将您的完整数据集加载到内存中。
在内部,它只是创建一系列函数,当使用model.fit()
调用时,它只会加载内存中的批次,而不是完整的数据集。
您可以在link 中阅读更多内容,我正在粘贴文档中的重要部分。
tf.data.Dataset API 支持编写描述性和高效 输入管道。数据集的使用遵循一个共同的模式:
从您的输入数据创建一个源数据集。应用数据集 对数据进行预处理的转换。遍历数据集并 处理元素。迭代以流的方式发生,所以 完整的数据集不需要放入内存。
从最后一行可以了解到tf.data.Dataset
API 不会将完整的数据集加载到内存中,而是一次加载一批。
您必须执行以下操作才能批量创建数据集。
train_ds.batch(32)
这将创建大小为32
的批次。您还可以使用预取来准备一批用于训练的批次。这消除了模型在训练一个批次并等待另一批次后空闲的瓶颈。
train_ds.batch(32).prefetch(1)
您还可以使用cache
API 使您的数据管道更快。它将缓存您的数据集并使训练速度更快。
train_ds.batch(32).prefetch(1).cache()
简而言之,如果您担心将整个数据集加载到内存中,则不需要generator
,tf.data.Dataset
API 会处理它。
我希望我的回答能找到你。
【讨论】:
感谢您的回复!在尝试生成器机制之前,我做了与您建议的完全相同的(批处理+预取+缓存)。但是,感谢您对 cache() 的提醒。我发现问题出在 cache() 上,它似乎阻止了经过训练的批次从 RAM 中移出,至少在我使用的 Google Colab 环境中是这样。所以有效的方法不是使用cache(),只是:train_ds = train_ds.prefetch(1)
我很高兴它对你有用。我很高兴得知在 google colab 模式下缓存可能会导致问题。如果我的回答对你有帮助,我会要求你投票赞成,因为以后它也可能对其他人有帮助。
我已经赞成您的回答,但它并没有出现在我身边。我的名声还不够~我是 *** 的新手。以上是关于如何在 tf.data.Dataset 对象上使用序列/生成器将部分数据放入内存?的主要内容,如果未能解决你的问题,请参考以下文章
Tensorflow:如何查找 tf.data.Dataset API 对象的大小
如何在 tensorflow tf.data.Dataset 中使用 cv2 图像增强功能?
如何在 tf.data.Dataset.map 中使用 sklearn.preprocessing?
如何将 tf.data.Dataset 与 kedro 一起使用?