使用数据增强层在 Tensorflow 2.7.0 上保存模型
Posted
技术标签:
【中文标题】使用数据增强层在 Tensorflow 2.7.0 上保存模型【英文标题】:Saving model on Tensorflow 2.7.0 with data augmentation layer 【发布时间】:2021-12-25 13:33:02 【问题描述】:尝试使用 Tensorflow 版本 2.7.0 保存具有数据增强层的模型时出现错误。
这是数据增强的代码:
input_shape_rgb = (img_height, img_width, 3)
data_augmentation_rgb = tf.keras.Sequential(
[
layers.RandomFlip("horizontal"),
layers.RandomFlip("vertical"),
layers.RandomRotation(0.5),
layers.RandomZoom(0.5),
layers.RandomContrast(0.5),
RandomColorDistortion(name='random_contrast_brightness/none'),
]
)
现在我像这样构建我的模型:
# Build the model
input_shape = (img_height, img_width, 3)
model = Sequential([
layers.Input(input_shape),
data_augmentation_rgb,
layers.Rescaling((1./255)),
layers.Conv2D(16, kernel_size, padding=padding, activation='relu', strides=1,
data_format='channels_last'),
layers.MaxPooling2D(),
layers.BatchNormalization(),
layers.Conv2D(32, kernel_size, padding=padding, activation='relu'), # best 4
layers.MaxPooling2D(),
layers.BatchNormalization(),
layers.Conv2D(64, kernel_size, padding=padding, activation='relu'), # best 3
layers.MaxPooling2D(),
layers.BatchNormalization(),
layers.Conv2D(128, kernel_size, padding=padding, activation='relu'), # best 3
layers.MaxPooling2D(),
layers.BatchNormalization(),
layers.Flatten(),
layers.Dense(128, activation='relu'), # best 1
layers.Dropout(0.1),
layers.Dense(128, activation='relu'), # best 1
layers.Dropout(0.1),
layers.Dense(64, activation='relu'), # best 1
layers.Dropout(0.1),
layers.Dense(num_classes, activation = 'softmax')
])
model.compile(loss='categorical_crossentropy', optimizer='adam',metrics=metrics)
model.summary()
然后在训练完成后我就做:
model.save("./")
我收到了这个错误:
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
<ipython-input-84-87d3f09f8bee> in <module>()
----> 1 model.save("./")
/usr/local/lib/python3.7/dist-packages/keras/utils/traceback_utils.py in
error_handler(*args, **kwargs)
65 except Exception as e: # pylint: disable=broad-except
66 filtered_tb = _process_traceback_frames(e.__traceback__)
---> 67 raise e.with_traceback(filtered_tb) from None
68 finally:
69 del filtered_tb
/usr/local/lib/python3.7/dist-
packages/tensorflow/python/saved_model/function_serialization.py in
serialize_concrete_function(concrete_function, node_ids, coder)
66 except KeyError:
67 raise KeyError(
---> 68 f"Failed to add concrete function 'concrete_function.name' to
object-"
69 f"based SavedModel as it captures tensor capture!r which is
unsupported"
70 " or not reachable from root. "
KeyError: "Failed to add concrete function
'b'__inference_sequential_46_layer_call_fn_662953'' to object-based SavedModel as it
captures tensor <tf.Tensor: shape=(), dtype=resource, value=<Resource Tensor>> which
is unsupported or not reachable from root. One reason could be that a stateful
object or a variable that the function depends on is not assigned to an attribute of
the serialized trackable object (see SaveTest.test_captures_unreachable_variable)."
我通过更改模型的架构检查了出现此错误的原因,我发现原因来自 data_augmentation 层,因为 RandomFlip
和 RandomRotation
以及其他从 layers.experimental.prepocessing.RandomFlip
更改为 layers.RandomFlip
,但仍然出现错误。
【问题讨论】:
【参考方案1】:当使用model.save
结合默认设置的参数save_format="tf"
时,这似乎是Tensorflow 2.7 中的一个错误。 RandomFlip
、RandomRotation
、RandomZoom
和 RandomContrast
层导致了问题。有趣的是,Rescaling
层可以毫无问题地保存。一种解决方法是使用较旧的 Keras H5 格式 model.save("test", save_format='h5')
保存您的模型:
import tensorflow as tf
import numpy as np
class RandomColorDistortion(tf.keras.layers.Layer):
def __init__(self, contrast_range=[0.5, 1.5],
brightness_delta=[-0.2, 0.2], **kwargs):
super(RandomColorDistortion, self).__init__(**kwargs)
self.contrast_range = contrast_range
self.brightness_delta = brightness_delta
def call(self, images, training=None):
if not training:
return images
contrast = np.random.uniform(
self.contrast_range[0], self.contrast_range[1])
brightness = np.random.uniform(
self.brightness_delta[0], self.brightness_delta[1])
images = tf.image.adjust_contrast(images, contrast)
images = tf.image.adjust_brightness(images, brightness)
images = tf.clip_by_value(images, 0, 1)
return images
def get_config(self):
config = super(RandomColorDistortion, self).get_config()
config.update("contrast_range": self.contrast_range, "brightness_delta": self.brightness_delta)
return config
input_shape_rgb = (256, 256, 3)
data_augmentation_rgb = tf.keras.Sequential(
[
tf.keras.layers.RandomFlip("horizontal"),
tf.keras.layers.RandomFlip("vertical"),
tf.keras.layers.RandomRotation(0.5),
tf.keras.layers.RandomZoom(0.5),
tf.keras.layers.RandomContrast(0.5),
RandomColorDistortion(name='random_contrast_brightness/none'),
]
)
input_shape = (256, 256, 3)
padding = 'same'
kernel_size = 3
model = tf.keras.Sequential([
tf.keras.layers.Input(input_shape),
data_augmentation_rgb,
tf.keras.layers.Rescaling((1./255)),
tf.keras.layers.Conv2D(16, kernel_size, padding=padding, activation='relu', strides=1,
data_format='channels_last'),
tf.keras.layers.MaxPooling2D(),
tf.keras.layers.BatchNormalization(),
tf.keras.layers.Conv2D(32, kernel_size, padding=padding, activation='relu'), # best 4
tf.keras.layers.MaxPooling2D(),
tf.keras.layers.BatchNormalization(),
tf.keras.layers.Conv2D(64, kernel_size, padding=padding, activation='relu'), # best 3
tf.keras.layers.MaxPooling2D(),
tf.keras.layers.BatchNormalization(),
tf.keras.layers.Conv2D(128, kernel_size, padding=padding, activation='relu'), # best 3
tf.keras.layers.MaxPooling2D(),
tf.keras.layers.BatchNormalization(),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(128, activation='relu'), # best 1
tf.keras.layers.Dropout(0.1),
tf.keras.layers.Dense(128, activation='relu'), # best 1
tf.keras.layers.Dropout(0.1),
tf.keras.layers.Dense(64, activation='relu'), # best 1
tf.keras.layers.Dropout(0.1),
tf.keras.layers.Dense(5, activation = 'softmax')
])
model.compile(loss='categorical_crossentropy', optimizer='adam')
model.summary()
model.save("test", save_format='h5')
使用自定义层加载模型将如下所示:
model = tf.keras.models.load_model('test.h5', custom_objects='RandomColorDistortion': RandomColorDistortion)
RandomColorDistortion
是您的自定义层的名称。
【讨论】:
当我尝试使用 h5 格式时,我得到另一个 NotImpleentedError 类型的错误。 不,您遇到的错误与您的原始问题无关。正如 GitHub 中已经提到的,如果您打算稍后保存它,则必须将配置添加到您的自定义层。例如查看此帖子:***.com/questions/62280161/… 或仅查看我的答案中自定义层的配置。 啊,当然,我忘记了这个细节。但我还有一个问题,我正在使用模型保存稍后加载模型并将其转换为 .tflite,当尝试加载时,我收到有关客户层的错误,我将在 github 问题上出错,我们将讨论它。谢谢@AloneTogether 更新了答案,有用吗? 是的,我现在可以加载和转换模型,谢谢@AloneTogether。【参考方案2】:您还可以将 Keras 和 Tensorflow 降级到 2.6 版。
【讨论】:
这是我实际使用的解决方案,但我想使用最后一个 TF 版本以上是关于使用数据增强层在 Tensorflow 2.7.0 上保存模型的主要内容,如果未能解决你的问题,请参考以下文章
如何使用 Tensorflow 2.0 数据集在训练时执行 10 次裁剪图像增强
如何在 tfds.load() 之后在 TensorFlow 2.0 中应用数据增强
Tensorflow 对象检测 api:如何使用 imgaug 进行增强?