Keras用动态数据生成器(DataGenerator)和fitgenerator动态训练模型
Posted szqfreiburger
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Keras用动态数据生成器(DataGenerator)和fitgenerator动态训练模型相关的知识,希望对你有一定的参考价值。
最近做Kaggle的图像分类比赛:RSNA Intracranial Hemorrhage Detection (https://www.kaggle.com/c/rsna-intracranial-hemorrhage-detection/overview)以及阅读Yolov3
源码的时候接触到深度学习训练时一个有趣的技巧,那就是构造生成器generator 并且用Keras 的fit_generator来批量生成数据,释放内存,该方法适合于大规模数据集的训练。一个DataGenerator是keras的Sequence类的继承类,一般要包含__len__,__getitem__, on_epoch_end等方法,例如下面的批量图片数据生成器:
class DataGenerator(keras.utils.Sequence): def __init__(self, list_IDs, labels, batch_size=1, img_size=(512, 512), img_dir, *args, **kwargs): """ self.list_IDs:存放所有需要训练的图片文件名的列表。 self.labels:记录图片标注的分类信息的pandas.DataFrame数据类型,已经预先给定。 self.batch_size:每次批量生成,训练的样本大小。 self.img_size:训练的图片尺寸。 self.img_dir:图片在电脑中存放的路径。 """ self.list_IDs = list_IDs self.labels = labels self.batch_size = batch_size self.img_size = img_size self.img_dir = img_dir self.on_epoch_end() def __len__(self): """ 返回生成器的长度,也就是总共分批生成数据的次数。 """ return int(ceil(len(self.list_IDs) / self.batch_size)) def __getitem__(self, index): """ 该函数返回每次我们需要的经过处理的数据。 """ indices = self.indices[index*self.batch_size:(index+1)*self.batch_size] list_IDs_temp = [self.list_IDs[k] for k in indices] X, Y = self.__data_generation(list_IDs_temp) return X, Y def on_epoch_end(self): """ 该函数将在训练时每一个epoch结束的时候自动执行,在这里是随机打乱索引次序以方便下一batch运行。 """ self.indices = np.arange(len(self.list_IDs)) np.random.shuffle(self.indices) def __data_generation(self, list_IDs_temp): """ 给定文件名,生成数据。 """ X = np.empty((self.batch_size, *self.img_size, 1)) Y = np.empty((self.batch_size, 6), dtype=np.float32) for i, ID in enumerate(list_IDs_temp): X[i,] = mpimg.imread(self.img_dir+ID+".png") Y[i,] = self.labels.loc[ID].values return X, Y
有了这个生成器,我们就可以用fit_generator 方法进行训练,格式套路如下:
model.fit_generator(generator,
steps_per_epoch=...,
epochs=...,
verbose=...,
callbacks=...,
validation_data=...,
validation_steps=...,
validation_freq=...,
class_weight=None=...,
max_queue_size=...
workers=...,
use_multiprocessing=...,
)
除此以外我们还可以搞批量预测:
model.predict_generator()
以上是关于Keras用动态数据生成器(DataGenerator)和fitgenerator动态训练模型的主要内容,如果未能解决你的问题,请参考以下文章
Keras图像分割实战:数据整理分割自定义数据生成器模型训练