无法从 keras 中的检查点模型恢复训练

Posted

技术标签:

【中文标题】无法从 keras 中的检查点模型恢复训练【英文标题】:Unable to resume training from checkpoint model in keras 【发布时间】:2018-08-05 00:31:24 【问题描述】:

当模型在准确性方面超过前一个时期时,我会在每个时期保存模型。但是当我加载模型时,它不会从保存的模型点恢复。代码如下:

filepath = "weights-improvement-epoch:02d-loss:.2f.hdf5"
callbacks = ModelCheckpoint(filepath, monitor='loss', verbose=0, save_best_only=True, save_weights_only=True,
                            mode='min')
model = load_model(current_dir + '\\' + 'weights-improvement-45-0.67.hdf5')

#model = load_model(current_dir + '\\' + 'weights-improvement-83-0.01.hdf5')
for j in range(n_repeats):
    csv_logger = CSVLogger('log' + str(i) + '_' + str(j) + '.csv', append=True, separator=';')
    print('training on cell array size' + str(cell_size_array[i]) + 'repeat of ' + str(j))

    history = model.fit_generator(get_input_output_spect_yeild(param_dict['dat_dir_train'],meanAbs,stdAbs,meanPhase,stdPhase ),
                                  validation_data=get_input_output_spect_yeild(param_dict['dat_dir_validation'],meanAbs,stdAbs,meanPhase,stdPhase),
                                  validation_steps=val_per_ep, steps_per_epoch=step_per_ep, epochs=num_epochs,
                                  verbose=1, callbacks=[csv_logger, callbacks])

【问题讨论】:

【参考方案1】:

在使用 model.fit_generator 之前,您必须使用保存的权重加载模型。 model.load_weights('最佳权重路径')

【讨论】:

以上是关于无法从 keras 中的检查点模型恢复训练的主要内容,如果未能解决你的问题,请参考以下文章

AI - TensorFlow - 示例05:保存和恢复模型

恢复预训练模型的 TensorFlow 检查点文件

无法从 Pytorch-Lightning 中的检查点加载模型

Keras:如何保存模型并继续训练?

基于CNN卷积神经网络的TensorFlow+Keras深度学习的人脸识别

ValueError:检查目标时出错:预期(keras 序列模型层)具有 n 维,但得到的数组具有形状