on_epoch_end() 未在 keras fit_generator() 中调用

Posted

技术标签:

【中文标题】on_epoch_end() 未在 keras fit_generator() 中调用【英文标题】:on_epoch_end() not called in keras fit_generator() 【发布时间】:2020-04-25 23:19:03 【问题描述】:

我跟随this tutorial 使用fit_generator() Keras 方法即时生成数据,以训练我的神经网络模型。

我使用keras.utils.Sequence 类创建了一个生成器。对fit_generator() 的调用是:

history = model.fit_generator(generator=EVDSSequence(images_train, TRAIN_BATCH_SIZE, INPUT_IMG_DIR, INPUT_JSON_DIR, SPLIT_CHAR, sizeArray, NCHW, shuffle=True),
                              steps_per_epoch=None, epochs=EPOCHS,
                              validation_data=EVDSSequence(images_valid, VALID_BATCH_SIZE, INPUT_IMG_DIR, INPUT_JSON_DIR, SPLIT_CHAR, sizeArray, NCHW, shuffle=True),
                              validation_steps=None,
                              callbacks=callbacksList, verbose=1,
                              workers=0, max_queue_size=1, use_multiprocessing=False)

steps_per_epochNone,所以每个epoch的步数是通过Keras的__len()__方法计算出来的。

如上链接所述:

这里,on_epoch_end 方法在每个 epoch 的开始和结束时触发一次。如果shuffle 参数设置为True,我们将在每次通过时获得一个新的探索顺序(否则保持线性探索方案)。

我的问题是 on_epoch_end() 方法只在开始时被调用,而不会在每个纪元结束时被调用。 因此,在每个 epoch,批次顺序始终相同。

我尝试在__len__() 方法中使用np.ceil 而不是np.floor,但没有成功。

你知道为什么 on_epoch_end 在每个 epoch 结束时不被调用吗?你能告诉我在每个时期结束(或开始)时调整批次顺序的任何解决方法吗?

非常感谢!

【问题讨论】:

【参考方案1】:

我遇到了同样的问题。我不知道为什么会发生这种情况,但有一种解决方法:在__len__() 内调用on_epoch_end(),因为__len__() 将在每个时期被调用。

【讨论】:

感谢您的回复!无论如何,几天前我已经(临时)以这种方式解决了,但这会打乱所有样本。我最初的目标是只打乱批次提供给网络的顺序,保留单个批次中样品的顺序。 非常感谢,我也遇到了同样的问题,直到我发现我的实验结果很奇怪时才意识到。【参考方案2】:

可能与问题有关:Keras model.fit not calling Sequence.on_epoch_end() #35911

快速解决方法是使用LambdaCallback(请注意,我使用fit 就足够了,因为不推荐使用fit_generator

from tf.keras.callbacks import LambdaCallback

model.fit(generator, callbacks=[LambdaCallback(on_epoch_end=generator.on_epoch_end)])

希望对你有帮助!

【讨论】:

【参考方案3】:

而且我发现当您创建 on_predict_end() callback_lambda 时,它不会在预测结束时调用。顺便说一句,predict() 接受一个 callbacks=list(...) 参数。

此外,您似乎可以像这样测试回调:

(create your 'model' object)
callback_batch_end <- callback_lambda(
    on_batch_end = function(batch, logs) 
        cat("Hello world\n")
    
)
callback_batch_end$on_batch_end(1, "x")
(prints 'Hello world')
callback_predict_end <- callback_lambda(
    on_predict_end = function(logs) 
        cat("Hello world\n")
    
)
callback_predict_end$on_predict_end("x")
(prints nothing)

【讨论】:

以上是关于on_epoch_end() 未在 keras fit_generator() 中调用的主要内容,如果未能解决你的问题,请参考以下文章

在keras中计算微F-1分数

<f:selectItems> 未在 <h:selectManyListbox> 中呈现

如何修复错误:“vreinterpretq_u32_f64”未在此范围内声明 - 在 Android 上使用 Eigen 构建

original_keras_version = f.attrs[‘keras_version‘].decode(‘utf8‘)AttributeError: ‘str‘ object has no(

xgb, lgb, Keras, LR(收藏)

Tensorflow2深度学习基础和tf.keras