model.fit_generator 中的 steps_per_epoch 实际上在做啥?
Posted
技术标签:
【中文标题】model.fit_generator 中的 steps_per_epoch 实际上在做啥?【英文标题】:What is steps_per_epoch in model.fit_generator actually doing?model.fit_generator 中的 steps_per_epoch 实际上在做什么? 【发布时间】:2020-12-05 07:18:26 【问题描述】:在阅读了关于model.fit_generator
方法中steps_per_epoch
所需参数的Keras 文档后,我对它的理解是:
如果数据集包含“N”个样本,并且生成器函数(传递给 Keras)为每次调用返回“B = batch_size”样本数(这里,我认为调用是生成器函数的单个产量)和由于 steps_per_epoch = ceil(N/B) 生成器被调用steps_per_epoch
次,因此整个数据集在一个 epoch 后通过模型,并且每个 epoch 都重复相同的过程,直到训练完成。
为了测试我的理解是否正确,我实现了以下
import numpy as np
from keras.models import Sequential
from keras.layers import Dense
index = 0
def get_values(inputs, targets):
i = 0
while True:
yield inputs[i], targets[i]
i += 1
if i >= len(inputs):
i = 0
def get_batch(inputs, targets, batch_size=2):
global index
batch_X = []
batch_Y = []
for inp, targ in get_values(inputs, targets):
batch_X.append(inp)
batch_Y.append(targ)
if len(batch_X) >= batch_size:
yield np.array(batch_X), np.array(batch_Y)
index += 1
batch_X = []
batch_Y = []
data = list(range(10))
labels = [2*val for val in range(10)]
model = Sequential([
Dense(16, activation='relu', input_shape=(1, )),
Dense(1)
])
model.compile(optimizer='rmsprop', loss='mean_squared_error')
model.fit_generator(get_batch(data, labels, batch_size=2), steps_per_epoch=5, epochs=1, verbose=False)
print(index) # Should Print 5 but it prints 15
这个程序并不难理解...
但根据我的解释,它应该打印 5,但它打印 15。我对 steps_per_epoch
的解释错了吗?
如果是,请给我正确的解释steps_per_epoch
PS。我是 Tensorflow 和 Keras 的新手,提前致谢。
【问题讨论】:
【参考方案1】:没有检查您的代码,但您的原始解释是正确的。实际上,根据位于here 的文档,您可以省略每个时期的步骤,model.fit 会将数据集的长度 (N) 除以批量大小来确定步骤。我确实复制并运行了您的代码。猜猜它把索引打印为 5 是什么。我能想到的唯一可能不同的是导入。
【讨论】:
感谢您向我保证我的解释是正确的,而且您确定它为您打印了 5?对我来说,它打印 15 ...我认为 Keras 会事先获取额外的数据,以便对下一个时代有所帮助...@Gerry P 你怎么说? 肯定打印了 5 个以上是关于model.fit_generator 中的 steps_per_epoch 实际上在做啥?的主要内容,如果未能解决你的问题,请参考以下文章
使用Keras model.fit_generator生成器
keras中model.compile的参数'weighted_metrics'和model.fit_generator的参数'class_weight'之间的区别?
使用 keras.utils.Sequence 和 keras.model.fit_generator 时出现 KeyError。