使用 tf.data.Dataset 时,Model.fit() 方法的 shuffle 如何处理 Batches?
Posted
技术标签:
【中文标题】使用 tf.data.Dataset 时,Model.fit() 方法的 shuffle 如何处理 Batches?【英文标题】:How does Model.fit() method's shuffle deals with Batches when using a tf.data.Dataset? 【发布时间】:2021-01-29 00:44:34 【问题描述】:我正在使用张量流 2。
当使用带有tf.data.Dataset
的Model.fit()
方法时,会忽略参数“batch_size
”。因此,要批量训练我的模型,我必须首先通过调用 tf.data.Dataset.batch(batch_size)
将我的样本数据集更改为批量样本数据集。
然后,在阅读文档后,我并不清楚.fit()
方法将如何在每个时期对我的数据集进行洗牌。
由于我的数据集是批次数据集,它会在批次之间打乱(批次保持不变)?或者它会打乱所有样本,然后将它们重新组合成新批次(这是所需的行为)?
非常感谢您的帮助。
【问题讨论】:
【参考方案1】:使用tf.data.Dataset
API 时,shuffle
参数对fit
函数没有影响。
如果我们阅读documentation(重点是我的):
shuffle:布尔值(是否在每个 epoch 之前对训练数据进行混洗)或 str(用于 'batch')。 当 x 是生成器时忽略此参数。 'batch' 是处理 HDF5 数据限制的特殊选项;它以批量大小的块进行洗牌。当 steps_per_epoch 不是 None 时无效。
这不是很清楚,但我们可以提示使用 tf.data.Dataset
时会忽略 shuffle 参数,因为它的行为类似于生成器。
为了确定,让我们深入研究代码。如果我们查看fit
方法的代码,您会看到数据由一个特殊的类DataHandler
处理。查看这个类的代码,我们看到这是一个处理不同类型数据的适配器类。我们对处理 tf.data.Dataset 的类DatasetAdapter
感兴趣,可以看到这个类没有考虑到shuffle
参数:
def __init__(self,
x,
y=None,
sample_weights=None,
steps=None,
**kwargs):
super(DatasetAdapter, self).__init__(x, y, **kwargs)
# Note that the dataset instance is immutable, its fine to reuse the user
# provided dataset.
self._dataset = x
# The user-provided steps.
self._user_steps = steps
self._validate_args(y, sample_weights, steps)
如果您想打乱数据集,请使用tf.data.Dataset
API 中的shuffle 函数。
【讨论】:
感谢您的解释。更进一步,如果我想在训练期间对每个时期的数据集进行洗牌。做dataset.shuffle(len_dataset).batch(batch_size)
和model.fit(num_epochs, ....)
就够了吗?或者我必须打电话给dataset.shuffle(len_dataset).batch(batch_size).repeat()
和model.fit(steps_per_epoch=len_dataset//batch_size, ...)
。还是等价的?
@Matt 我很确定调用一次 shuffle 应该在每个 epoch 之间重新调整数据集。您可以通过一个玩具示例来验证这一点。以上是关于使用 tf.data.Dataset 时,Model.fit() 方法的 shuffle 如何处理 Batches?的主要内容,如果未能解决你的问题,请参考以下文章
如何在 keras 自定义回调中访问 tf.data.Dataset?
如何将 tf.data.Dataset 与 kedro 一起使用?
具有渴望模式的 TF.data.dataset.map(map_func)
如何在 tf.data.Dataset.map 中使用 sklearn.preprocessing?