批处理、重复和随机播放对 TensorFlow 数据集有啥作用?

Posted

技术标签:

【中文标题】批处理、重复和随机播放对 TensorFlow 数据集有啥作用?【英文标题】:What does batch, repeat, and shuffle do with TensorFlow Dataset?批处理、重复和随机播放对 TensorFlow 数据集有什么作用? 【发布时间】:2019-04-30 01:38:52 【问题描述】:

我目前正在学习 TensorFlow,但在下面的代码 sn-p 中遇到了困惑:

dataset = dataset.shuffle(buffer_size = 10 * batch_size) 
dataset = dataset.repeat(num_epochs).batch(batch_size)
return dataset.make_one_shot_iterator().get_next()

我知道首先数据集将保存所有数据,但 shuffle()repeat()batch() 对数据集做了什么? 请帮我举个例子和解释。

【问题讨论】:

【参考方案1】:

tf.Dataset 中的以下方法:

    repeat( count=0 ) 方法重复数据集count 的次数。 shuffle( buffer_size, seed=None, reshuffle_each_iteration=None) 该方法将数据集中的样本打乱。 buffer_size 是随机化并以tf.Dataset 返回的样本数。 batch(batch_size,drop_remainder=False) 创建数据集的批次,批次大小为batch_size,这也是批次的长度。

【讨论】:

谢谢。我很困惑tensorflow.keras.preprocessing.timeseries_dataset_from_array() 没有drop_remainder 参数。【参考方案2】:

更新:Here 是一个小型协作笔记本,用于演示此答案。

想象一下,你有一个数据集:[1, 2, 3, 4, 5, 6],那么:

ds.shuffle() 的工作原理

dataset.shuffle(buffer_size=3) 将分配一个大小为 3 的缓冲区来选择随机条目。此缓冲区将连接到源数据集。 我们可以这样成像:

Random buffer
   |
   |   Source dataset where all other elements live
   |         |
   ↓         ↓
[1,2,3] <= [4,5,6]

假设条目2 取自随机缓冲区。空闲空间由源缓冲区中的下一个元素填充,即4

2 <= [1,3,4] <= [5,6]

我们继续阅读直到什么都没有:

1 <= [3,4,5] <= [6]
5 <= [3,4,6] <= []
3 <= [4,6]   <= []
6 <= [4]      <= []
4 <= []      <= []

ds.repeat() 的工作原理

一旦从数据集中读取所有条目并尝试读取下一个元素,数据集就会抛出错误。 这就是ds.repeat() 发挥作用的地方。它将重新初始化数据集,再次像这样:

[1,2,3] <= [4,5,6]

ds.batch() 会产生什么

ds.batch() 将首先获取 batch_size 条目并从中生成一批。因此,我们的示例数据集的批大小为 3 将产生两个批记录:

[2,1,5]
[3,6,4]

由于我们在批处理之前有一个ds.repeat(),因此数据的生成将继续。但是由于ds.random(),元素的顺序会有所不同。应该考虑的是,由于随机缓冲区的大小,6 永远不会出现在第一批中。

【讨论】:

如果我不想因为数据是时间序列而对数据进行洗牌怎么办,我还能在不洗牌的情况下使用重复和批量大小吗? @alily,是的。那将是一个选择。另一种选择是让每个批次记录代表一个单独的时间序列记录。这样你就可以从使用 shuffle() 中受益。 @Seymour:顺序是 ds.shuffle(...).repeat().batch(..)。至少对于 TensorFlow 2.1.0。 为什么不 ds.shuffle(reshuffle_each_iteration=True).batch(...).repeat(...) ? 我不明白为什么 6 永远不会出现在第一批中?为什么不?批次不是随机抽取的吗?因此,例如,第一批可能是 [2, 3, 6]。【参考方案3】:

一个显示在 epoch 上循环的示例。运行此脚本后,请注意

dataset_gen1 - shuffle 操作产生更多随机输出(这在运行机器学习实验时可能更有用dataset_gen2 - 缺少随机播放操作会按顺序生成元素

此脚本中的其他补充

tf.data.experimental.sample_from_datasets - 用于组合两个数据集。请注意,这种情况下的 shuffle 操作将创建一个缓冲区,该缓冲区从两个数据集中平均采样。
import os

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" # to avoid all those prints
os.environ["TF_GPU_THREAD_MODE"] = "gpu_private" # to avoid large "Kernel Launch Time"

import tensorflow as tf
if len(tf.config.list_physical_devices('GPU')):
    tf.config.experimental.set_memory_growth(tf.config.list_physical_devices('GPU')[0], True)

class Augmentations:

    def __init__(self):
        pass

    @tf.function
    def filter_even(self, x):
        if x % 2 == 0:
            return False
        else:
            return True

class Dataset:

    def __init__(self, aug, range_min=0, range_max=100):
        self.range_min = range_min
        self.range_max = range_max
        self.aug = aug

    def generator(self):
        dataset = tf.data.Dataset.from_generator(self._generator
                        , output_types=(tf.float32), args=())

        dataset = dataset.filter(self.aug.filter_even)

        return dataset
    
    def _generator(self):
        for item in range(self.range_min, self.range_max):
            yield(item)

# Can be used when you have multiple datasets that you wish to combine
class ZipDataset:

    def __init__(self, datasets):
        self.datasets = datasets
        self.datasets_generators = []
    
    def generator(self):
        for dataset in self.datasets:
            self.datasets_generators.append(dataset.generator())
        return tf.data.experimental.sample_from_datasets(self.datasets_generators)

if __name__ == "__main__":
    aug = Augmentations()
    dataset1 = Dataset(aug, 0, 100)
    dataset2 = Dataset(aug, 100, 200)
    dataset = ZipDataset([dataset1, dataset2])

    epochs = 2
    shuffle_buffer = 10
    batch_size = 4
    prefetch_buffer = 5

    dataset_gen1 = dataset.generator().shuffle(shuffle_buffer).batch(batch_size).prefetch(prefetch_buffer)
    # dataset_gen2 = dataset.generator().batch(batch_size).prefetch(prefetch_buffer) # this will output odd elements in sequence 

    for epoch in range(epochs):
        print ('\n ------------------ Epoch:  ------------------'.format(epoch))
        for X in dataset_gen1.repeat(1): # adding .repeat() in the loop allows you to easily control the end of the loop
            print (X)
        
        # Do some stuff at end of loop

【讨论】:

以上是关于批处理、重复和随机播放对 TensorFlow 数据集有啥作用?的主要内容,如果未能解决你的问题,请参考以下文章

如何随机播放列表[重复]

tensorflow图片预处理,随机亮度,旋转,剪切,翻转。

tensorflow图片预处理,随机亮度,旋转,剪切,翻转。

使用TensorFlow对图像进行随机旋转的实现示例

使用TensorFlow对图像进行随机旋转的实现示例

HTML 音频 - 通过单击 [重复] 随机播放