批处理、重复和随机播放对 TensorFlow 数据集有啥作用?
Posted
技术标签:
【中文标题】批处理、重复和随机播放对 TensorFlow 数据集有啥作用?【英文标题】:What does batch, repeat, and shuffle do with TensorFlow Dataset?批处理、重复和随机播放对 TensorFlow 数据集有什么作用? 【发布时间】:2019-04-30 01:38:52 【问题描述】:我目前正在学习 TensorFlow,但在下面的代码 sn-p 中遇到了困惑:
dataset = dataset.shuffle(buffer_size = 10 * batch_size)
dataset = dataset.repeat(num_epochs).batch(batch_size)
return dataset.make_one_shot_iterator().get_next()
我知道首先数据集将保存所有数据,但 shuffle()
、repeat()
和 batch()
对数据集做了什么?
请帮我举个例子和解释。
【问题讨论】:
【参考方案1】:tf.Dataset 中的以下方法:
repeat( count=0 )
方法重复数据集count
的次数。
shuffle( buffer_size, seed=None, reshuffle_each_iteration=None)
该方法将数据集中的样本打乱。 buffer_size
是随机化并以tf.Dataset
返回的样本数。
batch(batch_size,drop_remainder=False)
创建数据集的批次,批次大小为batch_size
,这也是批次的长度。
【讨论】:
谢谢。我很困惑tensorflow.keras.preprocessing.timeseries_dataset_from_array()
没有drop_remainder
参数。【参考方案2】:
更新:Here 是一个小型协作笔记本,用于演示此答案。
想象一下,你有一个数据集:[1, 2, 3, 4, 5, 6]
,那么:
ds.shuffle() 的工作原理
dataset.shuffle(buffer_size=3)
将分配一个大小为 3 的缓冲区来选择随机条目。此缓冲区将连接到源数据集。
我们可以这样成像:
Random buffer
|
| Source dataset where all other elements live
| |
↓ ↓
[1,2,3] <= [4,5,6]
假设条目2
取自随机缓冲区。空闲空间由源缓冲区中的下一个元素填充,即4
:
2 <= [1,3,4] <= [5,6]
我们继续阅读直到什么都没有:
1 <= [3,4,5] <= [6]
5 <= [3,4,6] <= []
3 <= [4,6] <= []
6 <= [4] <= []
4 <= [] <= []
ds.repeat() 的工作原理
一旦从数据集中读取所有条目并尝试读取下一个元素,数据集就会抛出错误。
这就是ds.repeat()
发挥作用的地方。它将重新初始化数据集,再次像这样:
[1,2,3] <= [4,5,6]
ds.batch() 会产生什么
ds.batch()
将首先获取 batch_size
条目并从中生成一批。因此,我们的示例数据集的批大小为 3 将产生两个批记录:
[2,1,5]
[3,6,4]
由于我们在批处理之前有一个ds.repeat()
,因此数据的生成将继续。但是由于ds.random()
,元素的顺序会有所不同。应该考虑的是,由于随机缓冲区的大小,6
永远不会出现在第一批中。
【讨论】:
如果我不想因为数据是时间序列而对数据进行洗牌怎么办,我还能在不洗牌的情况下使用重复和批量大小吗? @alily,是的。那将是一个选择。另一种选择是让每个批次记录代表一个单独的时间序列记录。这样你就可以从使用 shuffle() 中受益。 @Seymour:顺序是 ds.shuffle(...).repeat().batch(..)。至少对于 TensorFlow 2.1.0。 为什么不 ds.shuffle(reshuffle_each_iteration=True).batch(...).repeat(...) ? 我不明白为什么 6 永远不会出现在第一批中?为什么不?批次不是随机抽取的吗?因此,例如,第一批可能是 [2, 3, 6]。【参考方案3】:一个显示在 epoch 上循环的示例。运行此脚本后,请注意
dataset_gen1
- shuffle 操作产生更多随机输出(这在运行机器学习实验时可能更有用)
dataset_gen2
- 缺少随机播放操作会按顺序生成元素
此脚本中的其他补充
tf.data.experimental.sample_from_datasets
- 用于组合两个数据集。请注意,这种情况下的 shuffle 操作将创建一个缓冲区,该缓冲区从两个数据集中平均采样。
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" # to avoid all those prints
os.environ["TF_GPU_THREAD_MODE"] = "gpu_private" # to avoid large "Kernel Launch Time"
import tensorflow as tf
if len(tf.config.list_physical_devices('GPU')):
tf.config.experimental.set_memory_growth(tf.config.list_physical_devices('GPU')[0], True)
class Augmentations:
def __init__(self):
pass
@tf.function
def filter_even(self, x):
if x % 2 == 0:
return False
else:
return True
class Dataset:
def __init__(self, aug, range_min=0, range_max=100):
self.range_min = range_min
self.range_max = range_max
self.aug = aug
def generator(self):
dataset = tf.data.Dataset.from_generator(self._generator
, output_types=(tf.float32), args=())
dataset = dataset.filter(self.aug.filter_even)
return dataset
def _generator(self):
for item in range(self.range_min, self.range_max):
yield(item)
# Can be used when you have multiple datasets that you wish to combine
class ZipDataset:
def __init__(self, datasets):
self.datasets = datasets
self.datasets_generators = []
def generator(self):
for dataset in self.datasets:
self.datasets_generators.append(dataset.generator())
return tf.data.experimental.sample_from_datasets(self.datasets_generators)
if __name__ == "__main__":
aug = Augmentations()
dataset1 = Dataset(aug, 0, 100)
dataset2 = Dataset(aug, 100, 200)
dataset = ZipDataset([dataset1, dataset2])
epochs = 2
shuffle_buffer = 10
batch_size = 4
prefetch_buffer = 5
dataset_gen1 = dataset.generator().shuffle(shuffle_buffer).batch(batch_size).prefetch(prefetch_buffer)
# dataset_gen2 = dataset.generator().batch(batch_size).prefetch(prefetch_buffer) # this will output odd elements in sequence
for epoch in range(epochs):
print ('\n ------------------ Epoch: ------------------'.format(epoch))
for X in dataset_gen1.repeat(1): # adding .repeat() in the loop allows you to easily control the end of the loop
print (X)
# Do some stuff at end of loop
【讨论】:
以上是关于批处理、重复和随机播放对 TensorFlow 数据集有啥作用?的主要内容,如果未能解决你的问题,请参考以下文章
tensorflow图片预处理,随机亮度,旋转,剪切,翻转。