在 TF Dataset 管道中调用 Keras 标准模型预处理函数

Posted

技术标签:

【中文标题】在 TF Dataset 管道中调用 Keras 标准模型预处理函数【英文标题】:Calling Keras standard model preprocessing functions in TF Dataset pipeline 【发布时间】:2022-01-03 15:06:13 【问题描述】:

我使用 Keras 附带的一些标准 CNN 模型作为我自己模型的基础 - 比如说 VGG16。到目前为止,我习惯于通过 Keras 图像数据生成器调用相应的预处理函数,如下所示:

ImageDataGenerator(preprocessing_function=vgg16.preprocess_input)  # or any other std. model

现在我想改用 TF Dataset,这样我就可以使用它的 from_tensor_slices() 方法,这使得多 GPU 训练更容易。我为这个新管道提出了以下自定义预处理功能:

@tf.function
def load_images(image_path, label):
    image = tf.io.read_file(image_path)
    image = tf.image.decode_jpeg(image, channels=3)
    image = vgg16.preprocess_input(image)  # Is this call correct?
    image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
    return (image, label)

但我不确定这是否是函数调用的正确顺序,以及在此顺序中调用vgg16.preprocess_input(image) 的正确位置。我可以称这个标准吗?像这样的预处理功能,还是我需要在此之前/之后转换image 数据?

【问题讨论】:

【参考方案1】:

您可以使用路径和标签创建数据集from_tensor_slices(),然后使用map 加载和预处理图像:

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy
from PIL import Image

# Create random images
for i in range(3):
  imarray = numpy.random.rand(100,100,3) * 255
  im = Image.fromarray(imarray.astype('uint8'))
  im.save('result_image.jpeg'.format(i))

def load_images(image_path, label):
    image = tf.io.read_file(image_path)
    image = tf.image.decode_jpeg(image, channels=3)
    
    #preprocess_input --> will convert the input images from RGB to BGR, then will zero-center each color channel with respect to the ImageNet dataset, without scaling
    image = tf.keras.applications.vgg16.preprocess_input(image)
    image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
    image /= 255.0 
    return image, label

IMG_SIZE = 64
paths = ['result_image0.jpeg', 'result_image1.jpeg', 'result_image2.jpeg']
labels = [0, 1, 1]

dataset = tf.data.Dataset.from_tensor_slices((paths, labels))
ds = dataset.map(load_images)

image, _ = next(iter(ds.take(1)))
plt.imshow(image)

或者您可以使用tf.keras.applications.vgg16.preprocess_input 作为模型的一部分。例如:

preprocess = tf.keras.applications.vgg16.preprocess_input

some_input = tf.keras.layers.Input((256, 256, 3))
some_output = tf.keras.layers.Lambda(preprocess)(some_input)
model = tf.keras.Model(some_input, some_output)

model(tf.random.normal((2, 256, 256, 3)))

【讨论】:

为什么image /= 250 是必要的——那不是vgg16.preprocess_input(image) 的一部分吗? 根据docs,vgg16.preprocess_input 会将输入图像从 RGB 转换为 BGR,然后将相对于 ImageNet 数据集的每个颜色通道进行零中心化,而无需 缩放 i>. 好的,我明白了,但为什么是 250 而不是 255?这条线怎么样:image = tf.cast(image, tf.float32) / 255.0 抱歉打错了。

以上是关于在 TF Dataset 管道中调用 Keras 标准模型预处理函数的主要内容,如果未能解决你的问题,请参考以下文章

如何在 keras 自定义回调中访问 tf.data.Dataset?

如何将 Keras tf.data 与生成器( flow_from_dataframe )一起使用?形成完美的输入管道

tf.keras 模型 多个输入 tf.data.Dataset

使用 tf.keras.preprocessing.image_dataset_from_directory() 时如何在预测期间获取文件名?

Keras 扩充不适用于 tf.data.Dataset 映射

tf.keras.preprocessing.image_dataset_from_directory() 简介