如何确保预测时间图像输入与训练时间图像输入在同一范围内?

Posted

技术标签:

【中文标题】如何确保预测时间图像输入与训练时间图像输入在同一范围内?【英文标题】:How to make sure the prediction time image input is in the same range as the training time image input? 【发布时间】:2021-11-26 01:10:13 【问题描述】:

这个问题是关于确保预测时间输入图像与训练期间输入的图像在同一范围内。我知道通常的做法是重复训练期间完成的相同步骤以在预测时间处理图像。但就我而言,我在训练期间在自定义数据生成器中应用了random_trasnform() 函数,在预测期间添加它没有意义。


import cv2
import tensorflow as tf
import seaborn as sns

为了简化我的问题,假设我正在对我在自定义数据生成器中读取的灰度图像进行以下更改。

img_1 是数据生成器的输出,应该是 VGG19 模型的输入。

# using a simple augmenter
augmenter = tf.keras.preprocessing.image.ImageDataGenerator(
    brightness_range=(0.75, 1.25),
    preprocessing_function=tf.keras.applications.vgg19.preprocess_input  # preprocessing function of VGG19
)

# read the image
img = cv2.imread('sphx_glr_plot_camera_001.png')
# add a random trasnform
img_1 = augmenter.random_transform(img)/255

上面的random_tranform()使得灰度值分布如下([0,1]之间):

plt.imshow(img_1); plt.show();
sns.histplot(img_1[:, :, 0].ravel());  # select the 0th layer and ravel because the augmenter stacks 3 layers of the grayscale image to make it an RGB image

现在,我想在预测时间做同样的事情,但是,我不想对图像应用随机变换,所以我只是将输入图像通过preprocessing_function()

# read image
img = cv2.imread('sphx_glr_plot_camera_001.png')
# pass through the preprocessing function
img_2 = tf.keras.applications.vgg19.preprocess_input(img)/255

但我无法像训练期间那样使输入在 [0, 1] 范围内。

plt.imshow(img_2); plt.show();
sns.histplot(img_2[:, :, 0].ravel());

这使得预测完全不正确。如何确保预测时模型的输入经历相同的步骤,以使它们最终具有与训练期间输入的输入相似的分布?我也不想在预测时添加random_transform()

【问题讨论】:

使用的图片:scipy-lectures.org/_images/sphx_glr_plot_camera_001.png 【参考方案1】:

我会建议在您的模型中添加一个per image standardization,这将确保您的训练集和推理中图像的平均值为 0,标准差为 1

【讨论】:

以上是关于如何确保预测时间图像输入与训练时间图像输入在同一范围内?的主要内容,如果未能解决你的问题,请参考以下文章

错误 "IndexError: 如何在Keras中使用训练好的模型预测输入图像?

使用分类输入数据和图像输入数据进行分类

如何在keras tensorflow中将图像作为输入并获取另一个图像作为输出

Dlib 特征数组作为 CNN 和预测的输入

如何在表单的输入类型=“文件”中添加图像并将它们都提交到同一个表单后生成缩略图

如何使用 LSTM 对图像进行时间序列预测?