Keras ImageDataGenerator Slow

Posted: 2017-04-25 14:33:22

【Problem Description】:

I'm looking for the best approach to train on larger-than-memory data in Keras, and currently I'm noticing that the vanilla ImageDataGenerator tends to be slower than I would hope.

I trained two networks on the Kaggle cats vs. dogs dataset (25,000 images):

1) This approach is exactly the code from: http://www.pyimagesearch.com/2016/09/26/a-simple-neural-network-with-python-and-keras/

2) Same as (1), but using an ImageDataGenerator instead of loading the data into memory

Note: for what follows, "preprocessing" means resizing, scaling, and flattening.
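
For context, approach (2) presumably looks something like the sketch below (the directory layout, image size, and the `model` variable are assumptions, not the original code):

from keras.preprocessing.image import ImageDataGenerator

# Preprocessing (resize + scale) happens inside the generator, batch by batch.
datagen = ImageDataGenerator(rescale=1. / 255)
train_flow = datagen.flow_from_directory(
    'train/',              # assumed layout: train/cats/*.jpg, train/dogs/*.jpg
    target_size=(32, 32),  # images are resized on the fly every epoch
    batch_size=32,
    class_mode='binary')
model.fit_generator(train_flow, steps_per_epoch=25000 // 32, epochs=50)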

I find the following on my GTX 970:

For network 1, each epoch takes approximately 0 seconds.

For network 2, each epoch takes approximately 36 seconds if the preprocessing is done inside the data generator.

For network 2, each epoch takes approximately 13 seconds if the preprocessing is done in a first pass outside of the data generator.

Is this likely the speed limit for ImageDataGenerator (13 s looks like the usual 10-100x gap between disk and memory...)? Is there an approach/mechanism better suited for training on larger-than-memory data when using Keras? For example, is there perhaps a way to get Keras's ImageDataGenerator to save its processed images after the first epoch?

Thanks!

【Comments】:

Although a bit old, this post is relevant: Slow image data generator. The posts suggest that Keras (at least at some point in the past) applied multiple sequential transformations where a single transformation could have been used. Also see this: github.com/stratospark/keras-multiprocess-image-data-generator/…

【Answer 1】:

As I understand it, the problem is that augmented images are used only once in the model's training cycle, not even across several epochs. So it's a huge waste of GPU cycles while the CPU is struggling. I found the following solution:

1) I generate as many augmentations in RAM as it can hold.
2) I train on them within a frame of 10 to 30 epochs, whatever it takes to reach noticeable convergence.
3) After that, I generate a fresh batch of augmented images (by implementing on_epoch_end) and the process goes on.

This approach keeps the GPU busy most of the time while still benefiting from data augmentation. I use a custom Sequence subclass to generate the augmentations and fix the class imbalance at the same time.

EDIT: adding some code to clarify the idea

from pyutilz.string import read_config_file
from tqdm.notebook import tqdm
from pathlib import Path
from gc import collect
import numpy as np
import tensorflow
import logging
import random
import json
import cv2

class StoppingFromFile(tensorflow.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        if read_config_file('control.ini', 'ML', 'stop', globals()):
            if stop is not None:
                if stop == True or stop == 'True':
                    logging.warning('Model should be stopped according to the control file')
                    self.model.stop_training = True
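
# Note (assumption): read_config_file comes from the author's pyutilz package;
# it appears to parse an INI section and inject the parsed value into globals()
# as `stop`. The control.ini it polls might look roughly like:
#
#   [ML]
#   stop = True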

class AugmentedBalancedSequence(tensorflow.keras.utils.Sequence):
    def __init__(self, images_and_classes: dict, input_size: tuple, class_sizes: list,
                 augmentations_fn: object, preprocessing_fn: object, batch_size: int = 10,
                 num_class_samples=100, frame_length: int = 5, aug_p: float = 0.1,
                 aug_pipe_p: float = 0.2, is_validation: bool = False,
                 disk_saving_prob: float = .01, disk_example_nfiles: int = 50):
        """
        From a dict of images grouped by class label, builds an augmented, balanced training set every N epochs.
        If the current class is too scarce, ensures that the current frame has no duplicate final images.
        If it's rich enough, ensures that the current frame has no duplicate base images.
        """
        logging.info(f'Got {len(images_and_classes)} classes.')
        self.disk_example_nfiles = disk_example_nfiles
        self.disk_saving_prob = disk_saving_prob
        self.cur_example_file = 0

        self.images_and_classes = images_and_classes
        self.num_class_samples = num_class_samples
        self.augmentations_fn = augmentations_fn
        self.preprocessing_fn = preprocessing_fn

        self.is_validation = is_validation
        self.frame_length = frame_length
        self.batch_size = batch_size
        self.class_sizes = class_sizes
        self.input_size = input_size
        self.aug_pipe_p = aug_pipe_p
        self.aug_p = aug_p
        self.images = None
        self.epoch = 0
        #print(f'got frame_length={self.frame_length}')
        self._generate_data()
        

    def __len__(self):
        return int(np.ceil(len(self.images) / float(self.batch_size)))

    def __getitem__(self, idx):
        a = idx * self.batch_size
        b = a + self.batch_size
        return self.images[a:b], self.labels[a:b]
    
    def on_epoch_end(self):
        self.epoch += 1
        mydict = {}

        # Allow tweaking the sequence's attributes on the fly via a control file
        fname = 'control.json'
        p = Path(fname)
        if p.is_file():
            try:
                with open(fname) as f:
                    mydict = json.load(f)
                for var, val in mydict.items():
                    if hasattr(self, var):
                        converted = val  # ast.literal_eval(val)
                        if converted is not None:
                            if getattr(self, var) != converted:
                                setattr(self, var, converted)
                                print(f'{var} became {val}')
            except Exception as e:
                logging.error(str(e))
        if self.epoch % self.frame_length == 0:
            #print('generating data...')
            self._generate_data()
            
    def _add_sample(self, image, label):
        idx = self.indices[self.img_sent]

        # Occasionally dump an augmented sample to disk for visual inspection
        if self.disk_saving_prob > 0:
            if random.random() < self.disk_saving_prob:
                self.cur_example_file += 1
                if self.cur_example_file > self.disk_example_nfiles:
                    self.cur_example_file = 1
                Path(r'example_images/').mkdir(parents=True, exist_ok=True)
                cv2.imwrite(f'example_images/test{self.cur_example_file}.jpg',
                            cv2.cvtColor(image, cv2.COLOR_RGB2BGR))

        if self.preprocessing_fn:
            self.images[idx] = self.preprocessing_fn(image)
        else:
            self.images[idx] = image

        self.labels[idx] = label
        self.img_sent += 1
        
    def _generate_data(self):
        logging.info('Generating new set of augmented data...')

        collect()

        if self.num_class_samples:
            expected_length = len(self.images_and_classes) * self.num_class_samples
        else:
            expected_length = sum(self.class_sizes.values())

        # Allocate the frame buffers once, then reuse them for every new frame
        if self.images is None:
            self.images = np.empty((expected_length,) + (self.input_size[1],) + (self.input_size[0],) + (3,))
            self.labels = np.empty((expected_length), np.int32)

        # Random permutation of buffer slots, so samples of one class don't end up adjacent
        self.indices = np.random.choice(expected_length, expected_length, replace=False)
        self.img_sent = 0

        relaxed_augmentation_pipeline = self.augmentations_fn(p=self.aug_p, pipe_p=self.aug_pipe_p)
        maxed_out_augmentation_pipeline = self.augmentations_fn(p=self.aug_p, pipe_p=1.0)
        
        # For each class, fill this frame's buffer
        nartificial = 0
        for label, images in tqdm(self.images_and_classes.items()):
            if self.num_class_samples is None:
                # Just all native samples, without augmentations
                for image in images:
                    self._add_sample(image, label)
            else:
                # If there are enough native samples...
                if len(images) >= self.num_class_samples:
                    # randomly select the samples of this class which will participate in this frame of epochs
                    indices = np.random.choice(len(images), self.num_class_samples, replace=False)
                    # apply the albumentations pipeline to the selected samples
                    for idx in indices:
                        if not self.is_validation:
                            self._add_sample(relaxed_augmentation_pipeline(image=images[idx])['image'], label)
                        else:
                            self._add_sample(images[idx], label)
                else:
                    # Randomly pick the next image from the existing ones; keep applying the augmentation
                    # pipeline (with maxed-out probability) until we get num_class_samples DIFFERENT images
                    hashes = set()
                    norig = 0
                    while len(hashes) < self.num_class_samples:
                        if self.is_validation and norig < len(images):
                            # just include all the originals first
                            image = images[norig]
                        else:
                            image = maxed_out_augmentation_pipeline(image=random.choice(images))['image']
                        # a cheap "hash" to detect duplicate images within the frame
                        next_hash = np.sum(image)
                        if next_hash not in hashes or (self.is_validation and norig <= len(images)):
                            self._add_sample(image, label)
                            if next_hash in hashes:
                                # duplicate allowed for validation originals; grow the set with a counter
                                norig += 1
                                hashes.add(norig)
                            else:
                                hashes.add(next_hash)
                                nartificial += 1

        logging.info(f'Generated {self.img_sent} samples ({nartificial} artificial)')
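
The augmentations_fn argument is expected to return a callable pipeline that is invoked as pipeline(image=...)['image'], i.e. the albumentations calling convention. The answer doesn't include get_albumentations_pipeline itself; a minimal hypothetical sketch, assuming the albumentations library, might look like:

import albumentations as A

def get_albumentations_pipeline(p: float = 0.1, pipe_p: float = 0.2):
    # Hypothetical sketch: p is the probability of each individual transform,
    # pipe_p the probability that the whole pipeline fires at all; the actual
    # transforms the author used are not shown in the answer.
    return A.Compose([
        A.HorizontalFlip(p=p),
        A.RandomBrightnessContrast(p=p),
        A.ShiftScaleRotate(p=p),
    ], p=pipe_p)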

Once I have the images and classes loaded:

train_datagen = AugmentedBalancedSequence(
    images_and_classes=images_and_classes_train,
    input_size=INPUT_SIZE, class_sizes=class_sizes_train, num_class_samples=UPSCALE_SAMPLES,
    augmentations_fn=get_albumentations_pipeline, aug_p=AUG_P, aug_pipe_p=AUG_PIPE_P,
    preprocessing_fn=preprocess_input, batch_size=BATCH_SIZE, frame_length=FRAME_LENGTH,
    disk_saving_prob=0.05)

val_datagen = AugmentedBalancedSequence(
    images_and_classes=images_and_classes_val,
    input_size=INPUT_SIZE, class_sizes=class_sizes_val, num_class_samples=None,
    augmentations_fn=get_albumentations_pipeline, preprocessing_fn=preprocess_input,
    batch_size=BATCH_SIZE, frame_length=FRAME_LENGTH, is_validation=True)

and after the model is instantiated, I do:

model.fit(train_datagen,epochs=600,verbose=1,
          validation_data=(val_datagen.images,val_datagen.labels),validation_batch_size=BATCH_SIZE,
          callbacks=[checkpointer,StoppingFromFile()],validation_freq=1)

【Comments】:

Nice solution, do you have any code to share?

Thanks, added some; hope it helps or at least gives you useful ideas.

【Answer 2】:

I assume you might have already solved this, but nevertheless...

Keras image preprocessing has the option to save the results by setting the save_to_dir argument in the flow() or flow_from_directory() functions:

https://keras.io/preprocessing/image/
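
As a minimal sketch of that option (the directory names here are placeholders; the save_to_dir directory must already exist):

from keras.preprocessing.image import ImageDataGenerator

datagen = ImageDataGenerator(rescale=1. / 255)
flow = datagen.flow_from_directory(
    'data/train',             # placeholder input directory
    target_size=(150, 150),
    batch_size=32,
    save_to_dir='augmented',  # every yielded batch is also written here as image files
    save_prefix='aug',
    save_format='png')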

【Comments】:

(useful for visualizing what you are doing), as documented under flow > Arguments > save_to_dir
