Tensorflow 数据集预取和缓存选项的正确用途是啥?
Posted
技术标签:
【中文标题】Tensorflow 数据集预取和缓存选项的正确用途是啥?【英文标题】:What is the proper use of Tensorflow dataset prefetch and cache options?Tensorflow 数据集预取和缓存选项的正确用途是什么? 【发布时间】:2020-12-27 00:26:21 【问题描述】:我已经阅读了 TF 页面和一些帖子,以及关于使用 prefetch() 和 cache() 来加速模型输入管道并尝试在我的数据上实现它。 Cache() 按预期为我工作,即在第一个时期从 dist 读取数据,在所有后续时期中,它只是从内存中读取数据。但是我在使用 prefetch() 时遇到了很多困难,我真的不明白何时以及如何使用它。有人可以帮我吗?我真的需要一些帮助。 我的应用程序是这样的:我有一组大型 TFRecord 文件,每个文件都包含一些要在喂网之前处理的原始记录。它们将被混合(不同的样本流),所以我要做的是:
def read_datasets(pattern, numFiles, numEpochs=125, batchSize=1024, take=dataLength):
files = tf.data.Dataset.list_files(pattern)
def _parse(x):
x = tf.data.TFRecordDataset(x, compression_type='GZIP')
return x
np = 4 # half of the number of CPU cores
dataset = files.interleave(_parse, cycle_length=numFiles, block_length=1, num_parallel_calls=np)\
.map(lambda x: parse_tfrecord(x), num_parallel_calls=np)
dataset = dataset.take(take)
dataset = dataset.batch(batchSize)
dataset = dataset.cache()
dataset = dataset.prefetch(buffer_size=10)
dataset = dataset.repeat(numEpochs)
return dataset
interleave 函数中的 parse_tfrecord(x) 函数是数据应用于模型之前所需的数据预处理,我的猜测是预处理时间与网络的批处理时间相当。我的整个数据集(包括所有输入文件)包含大约 500 批 1024 个样本。我的问题是:
1- 如果我做缓存,我真的需要预取吗?
2- 映射、批处理、缓存、预取和重复的顺序是否正确?
3- Tensorflow 文档说预取的缓冲区大小是指数据集元素,如果它是批处理的,则指批处理的数量。所以在这种情况下我会读 10 批 1024 个例子,对吧?我的问题是,通过更改预取缓冲区大小,我看不到运行时间有任何差异,即使将缓冲区大小设置为 1000 或更大,内存消耗也没有太大变化。
【问题讨论】:
【参考方案1】:我为斯坦福大学的 Andrew Nu 找到了这个很好的解释。 https://cs230.stanford.edu/blog/datapipeline/#best-practices
“当GPU在当前批次上进行前向/反向传播时,我们希望CPU处理下一批数据以便立即准备好。作为计算机最昂贵的部分,我们希望GPU在训练过程中一直被充分使用。我们称之为消费者/生产者重叠,其中消费者是GPU,生产者是CPU。
使用 tf.data,您可以通过在管道末端(批处理后)简单调用 dataset.prefetch(1) 来完成此操作。这将始终预取一批数据并确保始终有一个准备好。
在某些情况下,预取多个批次可能很有用。例如,如果预处理的持续时间变化很大,预取 10 个批次将平均 10 个批次的处理时间,而不是有时等待更长的批次。
举个具体的例子,假设 10% 的批次需要 10 秒来计算,90% 需要 1 秒。如果 GPU 需要 2 秒来训练一个批次,那么通过预取多个批次可以确保我们永远不会等待这些罕见的较长批次。”
我不太确定如何确定每批的处理时间,但这是下一步。如果您的批处理所花费的时间大致相同,那么我相信 prefetch(batch_size=1) 就足够了,因为您的 GPU 不会等待 CPU 完成处理计算量大的批处理。
【讨论】:
【参考方案2】:您能否看看这个*** Answer 以快速了解TensorFlow Dataset 的函数cache()
和prefetch()
。
另外,我发现 Tensorflow Documentation 对优化 tf.Data
Api 的性能非常有帮助。他们已经为各种执行方式指定了基准和执行时间。您还可以分别找到有关数据的序列化和并行化加载和转换及其执行时间的信息。
【讨论】:
以上是关于Tensorflow 数据集预取和缓存选项的正确用途是啥?的主要内容,如果未能解决你的问题,请参考以下文章