keras中fit_generator()的优势
Posted
技术标签:
【中文标题】keras中fit_generator()的优势【英文标题】:Advantage of fit_generator() in keras 【发布时间】:2017-12-30 06:48:00 【问题描述】:我想知道 keras 中的 fit_generator()
在内存使用方面是否比使用通常的 fit()
方法和生成器产生的 batch_size
具有任何优势。我见过一些类似的例子:
def generator():
(X_train, y_train), (X_test, y_test) = mnist.load_data()
# some data prep
...
while 1:
for i in range(1875): # 1875 * 32 = 60000 -> # of training samples
yield X_train[i*32:(i+1)*32], y_train[i*32:(i+1)*32]
如果我将其传递给fit_generator()
方法,或者只是将所有数据直接传递给fit()
方法并定义一个为32 的batch_size
,它对(GPU?)内存有什么不同吗?
【问题讨论】:
【参考方案1】:是的,当您需要增强数据以获得更好的模型准确性时,差异实际上就出现了。
为了提高效率,它允许使用 CPU 对图像进行实时数据增强。这意味着它可以使用 GPU 进行模型训练并进行更新,同时将增强图像的负载委托给 CPU 并提供要训练的批次。
【讨论】:
但是在 GPU 上,内存使用仍然是一样的,对吗?它总是必须使用 fit() 和 fit_generator() 将完整的 batch_size 加载到 GPU 内存中? 是的,你是对的。我认为如果你必须引入增强功能并使用 keras ImageDataGenerator 作为生成器,它提供了许多功能,如旋转、缩放、剪切等,它们将从现有的训练数据,并且不会预先提供,并且可以并行处理这种额外的处理。看看 keras.io/preprocessing/image 的 ImageDataGenerator 感谢您的帮助。我希望有一种方法可以让我只将一批中的一部分加载到模型中,以节省 GPU 内存。但似乎你总是必须至少提供整批。例如。如果我的 batch_size 为 10,我想有一种方法只提供 5 个样本,然后在完成反向传播之前分别提供 5 个样本。也许这不能工作有一些限制。以上是关于keras中fit_generator()的优势的主要内容,如果未能解决你的问题,请参考以下文章
Keras:网络不使用 fit_generator() 进行训练
on_epoch_end() 未在 keras fit_generator() 中调用
使用 Keras 和 fit_generator 的 TensorBoard 分布和直方图
如何在Keras中使用fit_generator()来加权? [关闭]