keras.utils.Sequence使用注意事项
Posted mazinkaiser1991
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了keras.utils.Sequence使用注意事项相关的知识,希望对你有一定的参考价值。
1)在实现自己的DataLoader过程中一般都是继承自keras.utils.Sequence,继承该类必须要实现__len__与__getitem__两个函数。
2)在调用fit_generator进行训练时,如果设置了step_per_epoch参数,则每个epoch训练step_per_epoch个step,每个step有batch_size数据,因此每个epoch共训练step_per_epoch*batch_size的数据。如果没有设置step_per_epoch参数,则每个epoch训练的step个数由__len__决定。
3)在训练过程中step_per_epoch的个数可以大于 ceil(float(数据集图片数量)/batch_size) ,这个数字可以认为是遍历一遍数据集需要的实际step数量,__len__一般也实现为这个数字。在每遍历过一次数据集后(确切的说是调用__len__次),会调用一次on_epoch_end()。
4)__getitem__在调用时会有一个index参数,这个参数的取值范围就是range(__len__)的结果,index参数的值是在这个范围内随机给定的。因为__len__实现的时候使用的是ceil向上取整,因此很有可能最后一个index就无法取到一组满batch数据,因为数据集图片数量能够正好整除batch_size的情况很少。如果没有取到一组满batch数据,此时可以返回None,或者干脆什么都不返回。fit_generator在检查到是None的时候会再调用__getitem__一次。
5)所以这个地方要特别注意一点,图片无论是训练集还是验证集的数量一定不能小于batch_size,因为如果图片数量小于batch_size,则永远不能取到一组满batch,程序就会进入无限循环。另一方面在计算__len__的时候,使用了ceil,那么__len__至少大于等于1,也不存在不进入__getitem__的情况。除非数据集图片数量是0。
以上是关于keras.utils.Sequence使用注意事项的主要内容,如果未能解决你的问题,请参考以下文章
使用 keras.utils.Sequence 和 keras.model.fit_generator 时出现 KeyError。
keras.utils.Sequence:FileSequence生成文件序列流
on_epoch_end() 未在 keras fit_generator() 中调用
:模型训练和预测的三种方法(fit&tf.GradientTape&train_step&tf.data)