Saving model on Tensorflow 2.7.0 with data augmentation layer

Posted: 2021-12-25 13:33:02

I get an error when trying to save a model that has a data augmentation layer on Tensorflow version 2.7.0.

Here is the code for the data augmentation:

input_shape_rgb = (img_height, img_width, 3)
data_augmentation_rgb = tf.keras.Sequential(
  [ 
    layers.RandomFlip("horizontal"),
    layers.RandomFlip("vertical"),
    layers.RandomRotation(0.5),
    layers.RandomZoom(0.5),
    layers.RandomContrast(0.5),
    RandomColorDistortion(name='random_contrast_brightness/none'),
  ]
)

Now I build my model like this:

# Build the model
input_shape = (img_height, img_width, 3)

model = Sequential([
  layers.Input(input_shape),
  data_augmentation_rgb,
  layers.Rescaling((1./255)),

  layers.Conv2D(16, kernel_size, padding=padding, activation='relu', strides=1, 
     data_format='channels_last'),
  layers.MaxPooling2D(),
  layers.BatchNormalization(),

  layers.Conv2D(32, kernel_size, padding=padding, activation='relu'), # best 4
  layers.MaxPooling2D(),
  layers.BatchNormalization(),

  layers.Conv2D(64, kernel_size, padding=padding, activation='relu'), # best 3
  layers.MaxPooling2D(),
  layers.BatchNormalization(),

  layers.Conv2D(128, kernel_size, padding=padding, activation='relu'), # best 3
  layers.MaxPooling2D(),
  layers.BatchNormalization(),

  layers.Flatten(),
  layers.Dense(128, activation='relu'), # best 1
  layers.Dropout(0.1),
  layers.Dense(128, activation='relu'), # best 1
  layers.Dropout(0.1),
  layers.Dense(64, activation='relu'), # best 1
  layers.Dropout(0.1),
  layers.Dense(num_classes, activation = 'softmax')
])

model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=metrics)
model.summary()

Then, after training is done, I just do:

model.save("./")

And I get this error:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-84-87d3f09f8bee> in <module>()
----> 1 model.save("./")

/usr/local/lib/python3.7/dist-packages/keras/utils/traceback_utils.py in error_handler(*args, **kwargs)
     65     except Exception as e:  # pylint: disable=broad-except
     66       filtered_tb = _process_traceback_frames(e.__traceback__)
---> 67       raise e.with_traceback(filtered_tb) from None
     68     finally:
     69       del filtered_tb

/usr/local/lib/python3.7/dist-packages/tensorflow/python/saved_model/function_serialization.py in serialize_concrete_function(concrete_function, node_ids, coder)
     66   except KeyError:
     67     raise KeyError(
---> 68         f"Failed to add concrete function '{concrete_function.name}' to object-"
     69         f"based SavedModel as it captures tensor {capture!r} which is unsupported"
     70         " or not reachable from root. "

KeyError: "Failed to add concrete function 'b'__inference_sequential_46_layer_call_fn_662953'' to object-based SavedModel as it captures tensor <tf.Tensor: shape=(), dtype=resource, value=<Resource Tensor>> which is unsupported or not reachable from root. One reason could be that a stateful object or a variable that the function depends on is not assigned to an attribute of the serialized trackable object (see SaveTest.test_captures_unreachable_variable)."

I checked where this error comes from by changing the model's architecture, and I found that the cause is the data_augmentation layer. Since RandomFlip, RandomRotation and the others changed from layers.experimental.preprocessing.RandomFlip to layers.RandomFlip, I updated them accordingly, but the error still appears.

Comments:

Answer 1:

This seems to be a bug in Tensorflow 2.7 when using model.save with the parameter save_format="tf", which is the default. The RandomFlip, RandomRotation, RandomZoom, and RandomContrast layers are causing the problem. Interestingly, the Rescaling layer can be saved without any issue. A workaround is to save your model in the older Keras H5 format with model.save("test", save_format='h5'):

import tensorflow as tf
import numpy as np

class RandomColorDistortion(tf.keras.layers.Layer):
    def __init__(self, contrast_range=[0.5, 1.5], 
                 brightness_delta=[-0.2, 0.2], **kwargs):
        super(RandomColorDistortion, self).__init__(**kwargs)
        self.contrast_range = contrast_range
        self.brightness_delta = brightness_delta
    
    def call(self, images, training=None):
        if not training:
            return images
        contrast = np.random.uniform(
            self.contrast_range[0], self.contrast_range[1])
        brightness = np.random.uniform(
            self.brightness_delta[0], self.brightness_delta[1])
        
        images = tf.image.adjust_contrast(images, contrast)
        images = tf.image.adjust_brightness(images, brightness)
        images = tf.clip_by_value(images, 0, 1)
        return images
    
    def get_config(self):
        config = super(RandomColorDistortion, self).get_config()
        config.update({"contrast_range": self.contrast_range, "brightness_delta": self.brightness_delta})
        return config
        
input_shape_rgb = (256, 256, 3)
data_augmentation_rgb = tf.keras.Sequential(
  [ 
    tf.keras.layers.RandomFlip("horizontal"),
    tf.keras.layers.RandomFlip("vertical"),
    tf.keras.layers.RandomRotation(0.5),
    tf.keras.layers.RandomZoom(0.5),
    tf.keras.layers.RandomContrast(0.5),
    RandomColorDistortion(name='random_contrast_brightness/none'),
  ]
)
input_shape = (256, 256, 3)
padding = 'same'
kernel_size = 3
model = tf.keras.Sequential([
  tf.keras.layers.Input(input_shape),
  data_augmentation_rgb,
  tf.keras.layers.Rescaling((1./255)),
  tf.keras.layers.Conv2D(16, kernel_size, padding=padding, activation='relu', strides=1, 
     data_format='channels_last'),
  tf.keras.layers.MaxPooling2D(),
  tf.keras.layers.BatchNormalization(),

  tf.keras.layers.Conv2D(32, kernel_size, padding=padding, activation='relu'), # best 4
  tf.keras.layers.MaxPooling2D(),
  tf.keras.layers.BatchNormalization(),

  tf.keras.layers.Conv2D(64, kernel_size, padding=padding, activation='relu'), # best 3
  tf.keras.layers.MaxPooling2D(),
  tf.keras.layers.BatchNormalization(),

  tf.keras.layers.Conv2D(128, kernel_size, padding=padding, activation='relu'), # best 3
  tf.keras.layers.MaxPooling2D(),
  tf.keras.layers.BatchNormalization(),

  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(128, activation='relu'), # best 1
  tf.keras.layers.Dropout(0.1),
  tf.keras.layers.Dense(128, activation='relu'), # best 1
  tf.keras.layers.Dropout(0.1),
  tf.keras.layers.Dense(64, activation='relu'), # best 1
  tf.keras.layers.Dropout(0.1),
  tf.keras.layers.Dense(5, activation = 'softmax')
 ])

model.compile(loss='categorical_crossentropy', optimizer='adam')
model.summary()
model.save("test", save_format='h5')

Loading the model with the custom layer would look like this:

model = tf.keras.models.load_model('test.h5', custom_objects={'RandomColorDistortion': RandomColorDistortion})

where RandomColorDistortion is the name of your custom layer.
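The reason get_config matters here is that load_model rebuilds every layer from its serialized config. A framework-free sketch of that round trip (the class below is a plain-Python stand-in for illustration, not the actual Keras machinery):

```python
# Illustrative stand-in for a Keras custom layer: the dict returned by
# get_config must hold every constructor argument so that from_config
# can rebuild an identical layer when the model is reloaded.
class RandomColorDistortion:
    def __init__(self, contrast_range=(0.5, 1.5), brightness_delta=(-0.2, 0.2)):
        self.contrast_range = contrast_range
        self.brightness_delta = brightness_delta

    def get_config(self):
        # Everything __init__ needs must appear here, or reloading fails.
        return {"contrast_range": self.contrast_range,
                "brightness_delta": self.brightness_delta}

    @classmethod
    def from_config(cls, config):
        # Rebuild the layer purely from its saved config.
        return cls(**config)

layer = RandomColorDistortion(contrast_range=(0.8, 1.2))
rebuilt = RandomColorDistortion.from_config(layer.get_config())
print(rebuilt.contrast_range)  # → (0.8, 1.2)
```

If get_config omits an argument (or, as in the original snippet, passes a bare `"key": value` sequence instead of a dict to `config.update`), the rebuilt layer silently falls back to defaults or the save fails outright.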

Comments:

When I try to use the h5 format, I get another error, of type NotImplementedError.

No, the error you are getting now is unrelated to your original problem. As already mentioned on GitHub, you have to add a config to your custom layer if you plan to save it later. See for example this post: ***.com/questions/62280161/… or just look at the config of the custom layer in my answer.

Ah, of course, I forgot that detail. But I have one more question: I am saving the model so I can load it later and convert it to .tflite, and when trying to load it I get an error about the custom layer. I will post the error on the GitHub issue and we can discuss it there. Thanks @AloneTogether

Updated the answer, does it work?

Yes, I can now load and convert the model, thank you @AloneTogether.

Answer 2:

You can also downgrade Keras and Tensorflow to version 2.6.

Comments:

This is the solution I am actually using, but I wanted to use the latest TF version.
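For reference, the downgrade mentioned above would look roughly like this (illustrative pip invocations, not from the original post; pick the exact patch version yourself):

```shell
# Pin TensorFlow (which bundles a matching Keras) back to the 2.6 line,
# then confirm the interpreter picks up the downgraded version.
pip install "tensorflow==2.6.*"
python -c "import tensorflow as tf; print(tf.__version__)"
```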
