keras中fit_generator()的优势

Posted

技术标签:

【中文标题】keras中fit_generator()的优势【英文标题】:Advantage of fit_generator() in keras 【发布时间】:2017-12-30 06:48:00 【问题描述】:

我想知道 keras 中的 fit_generator() 在内存使用方面是否比使用通常的 fit() 方法和生成器产生的 batch_size 具有任何优势。我见过一些类似的例子:

def generator():
(X_train, y_train), (X_test, y_test) = mnist.load_data()
# some data prep
...
while 1:
    for i in range(1875): # 1875 * 32 = 60000 -> # of training samples
        yield X_train[i*32:(i+1)*32], y_train[i*32:(i+1)*32]

如果我将其传递给fit_generator() 方法,或者只是将所有数据直接传递给fit() 方法并定义一个为32 的batch_size,它对(GPU?)内存有什么不同吗?

【问题讨论】:

【参考方案1】:

是的,当您需要增强数据以获得更好的模型准确性时,差异实际上就出现了。

为了提高效率,它允许使用 CPU 对图像进行实时数据增强。这意味着它可以使用 GPU 进行模型训练并进行更新,同时将增强图像的负载委托给 CPU 并提供要训练的批次。

【讨论】:

但是在 GPU 上,内存使用仍然是一样的,对吗?它总是必须使用 fit() 和 fit_generator() 将完整的 batch_size 加载到 GPU 内存中? 是的,你是对的。我认为如果你必须引入增强功能并使用 keras ImageDataGenerator 作为生成器,它提供了许多功能,如旋转、缩放、剪切等,它们将从现有的训练数据,并且不会预先提供,并且可以并行处理这种额外的处理。看看 keras.io/preprocessing/image 的 ImageDataGenerator 感谢您的帮助。我希望有一种方法可以让我只将一批中的一部分加载到模型中,以节省 GPU 内存。但似乎你总是必须至少提供整批。例如。如果我的 batch_size 为 10,我想有一种方法只提供 5 个样本,然后在完成反向传播之前分别提供 5 个样本。也许这不能工作有一些限制。

以上是关于keras中fit_generator()的优势的主要内容,如果未能解决你的问题,请参考以下文章

Keras:网络不使用 fit_generator() 进行训练

on_epoch_end() 未在 keras fit_generator() 中调用

使用 Keras 和 fit_generator 的 TensorBoard 分布和直方图

如何在Keras中使用fit_generator()来加权? [关闭]

history=model.fit_generator() 为啥 keras 历史是空的?

keras:为 fit_generator 使用 ImageDataGenerator 和 KFold 的问题