model.fit_generator 中的 steps_per_epoch 实际上在做啥?

Posted

技术标签:

【中文标题】model.fit_generator 中的 steps_per_epoch 实际上在做啥?【英文标题】:What is steps_per_epoch in model.fit_generator actually doing?model.fit_generator 中的 steps_per_epoch 实际上在做什么? 【发布时间】:2020-12-05 07:18:26 【问题描述】:

在阅读了关于model.fit_generator 方法中steps_per_epoch 所需参数的Keras 文档后,我对它的理解是:

如果数据集包含“N”个样本,并且生成器函数(传递给 Keras)为每次调用返回“B = batch_size”样本数(这里,我认为调用是生成器函数的单个产量)和由于 steps_per_epoch = ceil(N/B) 生成器被调用steps_per_epoch 次,因此整个数据集在一个 epoch 后通过模型,并且每个 epoch 都重复相同的过程,直到训练完成。

为了测试我的理解是否正确,我实现了以下

import numpy as np
from keras.models import Sequential
from keras.layers import Dense

index = 0

def get_values(inputs, targets):
    i = 0
    while True:
        yield inputs[i], targets[i]
        i += 1
        if i >= len(inputs):
            i = 0


def get_batch(inputs, targets, batch_size=2):
    global index
    batch_X = []
    batch_Y = []
    for inp, targ in get_values(inputs, targets):
        batch_X.append(inp)
        batch_Y.append(targ)

        if len(batch_X) >= batch_size:
            yield np.array(batch_X), np.array(batch_Y)
            index += 1
            batch_X = []
            batch_Y = []


data = list(range(10))
labels = [2*val for val in range(10)]

model = Sequential([
    Dense(16, activation='relu', input_shape=(1, )),
    Dense(1)
])

model.compile(optimizer='rmsprop', loss='mean_squared_error')
model.fit_generator(get_batch(data, labels, batch_size=2), steps_per_epoch=5, epochs=1, verbose=False)

print(index) # Should Print 5 but it prints 15

这个程序并不难理解...

但根据我的解释,它应该打印 5,但它打印 15。我对 steps_per_epoch 的解释错了吗? 如果是,请给我正确的解释steps_per_epoch

PS。我是 Tensorflow 和 Keras 的新手,提前致谢。

【问题讨论】:

【参考方案1】:

没有检查您的代码,但您的原始解释是正确的。实际上,根据位于here 的文档,您可以省略每个时期的步骤,model.fit 会将数据集的长度 (N) 除以批量大小来确定步骤。我确实复制并运行了您的代码。猜猜它把索引打印为 5 是什么。我能想到的唯一可能不同的是导入。

【讨论】:

感谢您向我保证我的解释是正确的,而且您确定它为您打印了 5?对我来说,它打印 15 ...我认为 Keras 会事先获取额外的数据,以便对下一个时代有所帮助...@Gerry P 你怎么说? 肯定打印了 5 个

以上是关于model.fit_generator 中的 steps_per_epoch 实际上在做啥?的主要内容,如果未能解决你的问题,请参考以下文章

使用Keras model.fit_generator生成器

keras中model.compile的参数'weighted_metrics'和model.fit_generator的参数'class_weight'之间的区别?

keras fit_generator 并行

使用 keras.utils.Sequence 和 keras.model.fit_generator 时出现 KeyError。

keras 入门整理 如何shuffle,如何使用fit_generator

Keras - 管理历史