ImageDataGenerator:如何将第 4 维添加到 numpy 数组?

Posted

技术标签:

【中文标题】ImageDataGenerator:如何将第 4 维添加到 numpy 数组?【英文标题】:ImageDataGenerator: how to add the 4th dimension to a numpy array? 【发布时间】:2019-09-14 00:43:57 【问题描述】:

我有以下代码用 opencv 读取图像并显示它:

import cv2, matplotlib.pyplot as plt
img = cv2.imread('imgs_soccer/soccer_10.jpg',cv2.IMREAD_COLOR)
img = cv2.resize(img, (128, 128))
plt.imshow(img)
plt.show()

我想使用 keras 生成一些随机图像,所以我定义了这个生成器:

image_gen = ImageDataGenerator(rotation_range=15,
                           width_shift_range=0.1,
                           height_shift_range=0.1,
                           shear_range=0.01,
                           zoom_range=[0.9, 1.25],
                           horizontal_flip=True,
                           vertical_flip=False,
                           fill_mode='reflect',
                           data_format='channels_last',
                           brightness_range=[0.5, 1.5])

但是,当我以这种方式使用它时:

image_gen.flow(img)

我收到此错误:

'Input data in `NumpyArrayIterator` should have rank 4. You passed an array with shape', (128, 128, 3))

在我看来很明显:RGB,图像,当然是 3 维的! 我在这里想念什么? 文档说它需要一个 4 维数组,但没有指定 我应该在第 4 维中放入什么

这个 4-dim 数组应该如何制作?我现在有(宽度、高度、通道),这个第 4 维度位于 开始或结束

我对 numpy 也不是很熟悉:如何更改现有的 img 数组以添加第 4 维?

【问题讨论】:

【参考方案1】:

使用np.expand_dims():

import numpy as np
img = np.expand_dims(img, 0)
print(img.shape) # (1, 128, 128, 3)

第一个维度指定图像的数量(在您的情况下为 1 个图像)。

【讨论】:

啊,好吧,理论上我可以使用单个数组将任意数量的图像传递给生成器!谢谢 没错,很高兴为您提供帮助!【参考方案2】:

或者,您可以使用 numpy.newaxisNone 将您的 3D 阵列提升为 4D,如下所示:

img = img[np.newaxis, ...] 

# or use None
img = img[None, ...]

第一个维度通常是batch_size。当您想要充分利用 GPU 等现代硬件时,只要您的张量适合您的 GPU 内存,这将为您提供很大的灵活性。例如,您可以通过沿第一个维度堆叠 64 个图像来传递 64 个图像。在这种情况下,您的 4D 数组的形状将是 (64, width, height, channels)

【讨论】:

以上是关于ImageDataGenerator:如何将第 4 维添加到 numpy 数组?的主要内容,如果未能解决你的问题,请参考以下文章

Keras - 如何在不改变纵横比的情况下使用 ImageDataGenerator

在机器学习中,改组如何与 ImageDataGenerator 一起工作?

如何使用批处理为大型数据集拟合 Keras ImageDataGenerator

如何在使用 ImageDataGenerator 时获得基本事实和相应的分数

如何从大型 .h5 数据集中批量读取数据,使用 ImageDataGenerator 和 model.fit 进行预处理,所有这些都不会耗尽内存?

TensorFlow-windowskeras接口——ImageDataGenerator裁剪