高效标准化 Numpy 数组中的图像

Posted

技术标签:

【中文标题】高效标准化 Numpy 数组中的图像【英文标题】:Efficiently Standardizing Images in a Numpy Array 【发布时间】:2018-10-18 17:27:24 【问题描述】:

我有一个形状为 (N, H, W, C) 的图像的 numpy 数组,其中 N 是图像的数量,H 是图像高度,W 是图像宽度,C 是 RGB 通道。

我想按通道标准化我的图像,因此对于每个图像,我想按通道减去图像通道的平均值并除以其标准差。

我在一个循环中执行此操作,这很有效,但是效率非常低,并且因为它会复制我的 RAM 太满了。

def standardize(img):
    mean = np.mean(img)
    std = np.std(img)
    img = (img - mean) / std
    return img

for img in rgb_images:
    r_channel = standardize(img[:,:,0])
    g_channel = standardize(img[:,:,1])
    b_channel = standardize(img[:,:,2])
    normalized_image = np.stack([r_channel, g_channel, b_channel], axis=-1)
    standardized_images.append(normalized_image)
standardized_images = np.array(standardized_images)

如何更有效地利用 numpy 的功能?

【问题讨论】:

【参考方案1】:

沿第二个和第三个轴执行 ufunc 缩减(均值、标准差),同时保持维度不变,这有助于稍后在除法步骤中的 broadcasting -

mean = np.mean(rgb_images, axis=(1,2), keepdims=True)
std = np.std(rgb_images, axis=(1,2), keepdims=True)
standardized_images_out = (rgb_images - mean) / std

根据公式重新使用平均值来计算标准偏差,从而进一步提高性能,因此受到 this solution 的启发,就像这样 -

std = np.sqrt(((rgb_images - mean)**2).mean((1,2), keepdims=True))

将归约轴作为参数打包成一个函数,我们会有 -

from __future__ import division

def normalize_meanstd(a, axis=None): 
    # axis param denotes axes along which mean & std reductions are to be performed
    mean = np.mean(a, axis=axis, keepdims=True)
    std = np.sqrt(((a - mean)**2).mean(axis=axis, keepdims=True))
    return (a - mean) / std

standardized_images = normalize_meanstd(rgb_images, axis=(1,2))

【讨论】:

你能解释一下在这种情况下轴参数是如何工作的吗?我看不出这是可能的。并且有必要在以后做减法和除法吗?此外,很好的答案!我明天验证并给你学分。 @Chris 这应该有助于轴 - docs.scipy.org/doc/numpy-1.13.0/reference/ufuncs.html#methods。 keepdims 是必要的,是保留没有。昏暗,根据以后广播的需要。 为了澄清一点,您的操作也会制作副本,对吗?有没有办法就地进行减法和除法运算? @Chris 在最后一步,使用out = 参数为numpy.subtractnumpy.divide 替换那些对应的操作。 所以np.subtract(rgb_images, mean, out=rgb_images) 有效吗?或者写入与用作第一个参数的数组相同的数组是否会导致问题?和rgb_images -= mean相比有什么区别?

以上是关于高效标准化 Numpy 数组中的图像的主要内容,如果未能解决你的问题,请参考以下文章

如何规范化 4D numpy 数组?

python - 如何在python numpy中标准化二维数组的一维? [复制]

numpy.ndarray 如何标准化?

numpy数组-标准化数据

Numpy标准化多暗淡(> = 3)数组

pytorch学习笔记第五篇——训练分类器