keras ImageDataGenerator 插值二进制掩码

Posted

技术标签:

【中文标题】keras ImageDataGenerator 插值二进制掩码【英文标题】:keras ImageDataGenerator interpolates binary mask 【发布时间】:2020-06-16 16:34:21 【问题描述】:

我正在训练一个神经网络来预测小鼠大脑图像上的二进制掩码。为此,我使用 keras 的 ImageDataGenerator 扩充我的数据。

但我已经意识到数据生成器在应用空间变换时正在对数据进行插值。

这对图像来说很好,但我当然不希望我的掩码包含非二进制值。

在应用转换时有什么方法可以选择最近邻插值?我在 keras 文档中没有找到这样的选项。

(左边是原始的二进制掩码,右边是增强的插值掩码)

图片代码:

data_gen_args = dict(rotation_range=90,
                     width_shift_range=30,
                     height_shift_range=30,
                     shear_range=5,
                     zoom_range=0.3,
                     horizontal_flip=True,
                     vertical_flip=True,
                     fill_mode='nearest')
image_datagen = kp.image.ImageDataGenerator(**data_gen_args)
image_generator = image_datagen.flow(image, seed=1)
plt.figure()
plt.subplot(1,2,1)
plt.imshow(np.squeeze(image))
plt.axis('off')
plt.subplot(1,2,2)
plt.imshow(np.squeeze(image_generator.next()[0]))
plt.axis('off')
plt.savefig('vis/keras_example')

【问题讨论】:

【参考方案1】:

我自己的二进制图像数据也有同样的问题。有几种方法可以解决这个问题。

简单回答:我通过手动将 ImageDataGenerator 的结果转换为二进制解决了这个问题。如果您手动迭代生成器(使用 'next()' 方法或使用 'for' 循环),那么您可以简单地使用 numpy 'where' 方法将非二进制值转换为二进制:

import numpy as np

batch = image_generator.next()
binary_images = np.where(batch>0, 1, 0)  ## or batch>0.5 or any other thresholds

在 ImageDataGenerator 中使用 preprocessing_function 参数

另一种更好的方法是在ImageDataGenerator 中使用preprocessing_function 参数。正如documentation 中所写,可以指定将在数据增强过程之后执行的自定义预处理函数,因此您可以在data_gen_args 中指定此函数,如下所示:

from keras.preprocessing.image import ImageDataGenerator

data_gen_args = dict(rotation_range=90,
                     width_shift_range=30,
                     height_shift_range=30,
                     shear_range=5,
                     zoom_range=0.3,
                     horizontal_flip=True,
                     vertical_flip=True,
                     fill_mode='nearest',
                     preprocessing_function = lambda x: np.where(x>0, 1, 0).astype(x.dtype))

注意:根据我的经验,preprocessing_functionrescale 之前执行,这也可以在 data_gen_args 中指定为 ImageDataGenerator 的参数。这不是您的情况,但如果您需要指定该参数,请记住这一点。

创建自定义生成器

另一种解决方案是编写一个自定义数据生成器,并在其中修改 ImageDataGenerator 的输出。然后使用这个新的生成器来喂model.fit()。像这样的:

batch_size = 64
image_datagen = kp.image.ImageDataGenerator(**data_gen_args)
image_generator = image_datagen.flow(image, batch_size=batch_size, seed=1)
from tensorflow.keras.utils import Sequence
class MyImageDataGenerator(Sequence):
        def __init__(self, data_size, batch_size):
            self.data_size = data_size
            self.batch_size = batch_size
            super(MyImageDataGenerator).__init__()

        def __len__(self):
            return int(np.ceil(self.data_size / float(self.batch_size)))

        def __getitem__(self, idx):    
            augmented_data = image_generator.next()
            binary_images = np.where(augmented_data>0, 1, 0)
            return binary_images

my_image_generator = MyImageDataGenerator(data_size=len(image), batch_size=batch_size)
model.fit(my_image_generator, epochs=50)

上面的数据生成器也是一个简单的数据生成器。如果需要,您可以自定义它并添加标签(如this)或多模式数据等。

【讨论】:

以上是关于keras ImageDataGenerator 插值二进制掩码的主要内容,如果未能解决你的问题,请参考以下文章

使用 ImageDataGenerator 时 Keras 拆分训练测试集

Keras ImageDataGenerator 不处理符号链接文件

Keras ImageDataGenerator 慢

keras图片数据增强ImageDataGenerator

Keras - 如何在不改变纵横比的情况下使用 ImageDataGenerator

使用 ImageDataGenerator 进行 Keras 数据增强(您的输入没有数据)