我可以在 Keras 中使用 ImageDataGenerator() 和 flow_from_directory() 生成 uint8 标签吗?

Posted

技术标签:

【中文标题】我可以在 Keras 中使用 ImageDataGenerator() 和 flow_from_directory() 生成 uint8 标签吗?【英文标题】:Can I generate uint8 label using ImageDataGenerator() and flow_from_directory() in Keras? 【发布时间】:2019-02-15 18:33:06 【问题描述】:

我正在处理 2D 语义分割任务。

在 Keras API 文档中,这些示例仅显示如何安排数据集以进行图像分类,而不是语义分割。

所以我像这样安排我的图像和标签

SEED = 111
batch_size = 2
image_datagen = ImageDataGenerator(
    horizontal_flip=True,
    zca_epsilon=9,
    # fill_mode='nearest',
)
image_generator = image_datagen.flow_from_directory(
    directory="/xxx/images",
    class_mode=None,
    batch_size=batch_size,
    seed=SEED,
)


def preprocessing_function(image):
    return image.astype(np.uint8)


label_datagen = ImageDataGenerator(
    horizontal_flip=True,
    zca_epsilon=9,
    rescale=1,
    preprocessing_function=preprocessing_function,
    # fill_mode='nearest',
)
label_generator = image_datagen.flow_from_directory(
    directory="/xxx/labels",
    class_mode=None,
    batch_size=batch_size,
    seed=SEED,
)

train_generator = zip(image_generator, label_generator)
print(len(image_generator))
i = 0
for image_batch, label_batch in iter(train_generator):
    print(image_batch.shape, label_batch.shape) # (2, 256, 256, 3) (2, 256, 256, 3)
    print(image_batch.dtype, label_batch.dtype) # float32 float32
    i += 1
    if i == 5:
        break

但是我发现生成的label图片的类型是float32,所以我给label_datagen加了一个preprocessing_function函数只是为了把dtype转换成uint8,但是生成的label images的dtype还是float32,好像preprocessing_function什么都没做。

我该如何解决这个问题?

如何将我的标签数据更改为 uint8?

添加预处理函数来转换标签图像的dtype是一种“常见做法”吗?

感谢您的建议!

【问题讨论】:

为什么要转成uint8?如果你想将它提供给神经网络,那么它应该是浮点数据类型。 我只想将我的(语义分割)标签图像转换为 uint8,因为我认为它更好地计算损失值并节省 RAM @ShouyuChen 你有什么发现吗?我也想知道 【参考方案1】:

我遇到了同样的问题并将生成器包装到另一个中。它有效,但它有点杂乱无章

label_generator = (x.astype(np.uint8) for x in label_generator)
train_generator = zip(image_generator, label_generator)

【讨论】:

以上是关于我可以在 Keras 中使用 ImageDataGenerator() 和 flow_from_directory() 生成 uint8 标签吗?的主要内容,如果未能解决你的问题,请参考以下文章

Keras - 是不是可以在 Tensorboard 中查看模型的权重和偏差

我可以在 Keras 中使用 ImageDataGenerator() 和 flow_from_directory() 生成 uint8 标签吗?

在 keras 中使用批量标准化进行微调

在 keras 回调中使用带有自定义参数的自定义函数

如何在 keras 中进行深度学习中的多标签分类?

我可以在 keras 中训练的课程数量是不是有上限?