Keras 扩充不适用于 tf.data.Dataset 映射
Posted
技术标签:
【中文标题】Keras 扩充不适用于 tf.data.Dataset 映射【英文标题】:Keras augmentation does not work with tf.data.Dataset map 【发布时间】:2022-01-16 18:32:50 【问题描述】:我正在尝试使用预处理函数来处理数据集映射,但出现以下错误(底部的完整堆栈跟踪):
ValueError: Tensor-typed variable initializers must either be wrapped in an init_scope or callable (e.g., `tf.Variable(lambda : tf.truncated_normal([10, 40]))`) when building functions. Please file a feature request if this restriction inconveniences you.
下面是重现该问题的完整代码片段。我的问题是,为什么在一个用例(仅裁剪)中它可以工作,而使用 RandomFlip 时却不行?如何解决这个问题?
import functools
import numpy as np
import tensorflow as tf
def data_gen():
    """Yield ten random (image, label) pairs used to reproduce the issue.

    Yields:
        A tuple ``(x, y)`` where ``x`` is an 80x80x3 ``uint8`` array (RGB
        image) and ``y`` is a 40x40x1 ``uint8`` array (downsized mono image).
    """
    for _ in range(10):
        # Scale the random floats by 255, then truncate to 8-bit pixel values.
        x = np.random.random(size=(80, 80, 3)) * 255  # rgb image
        x = x.astype('uint8')
        y = np.random.random(size=(40, 40, 1)) * 255  # downsized mono image
        y = y.astype('uint8')
        yield x, y
def preprocess(image, label, cropped_image_size, cropped_label_size, skip_augmentations=False):
    """Optionally augment, then center-crop, one (image, label) pair.

    Written to be bound with ``functools.partial`` and passed to
    ``tf.data.Dataset.map``.

    NOTE: this is the question's reproduction and is intentionally left
    failing. With ``skip_augmentations=False`` the ``RandomFlip`` layer is
    constructed while ``Dataset.map`` is tracing this function, which is
    where the ``ValueError`` in the stack trace below is raised.

    Args:
        image: Image tensor for one element of the dataset.
        label: Label tensor for the same element.
        cropped_image_size: Height and width of the cropped image.
        cropped_label_size: Height and width of the cropped label.
        skip_augmentations: When True, only the center crop is applied.

    Returns:
        The ``(image, label)`` pair after the optional augmentations and the crop.
    """
    x = image
    y = label
    x_size = cropped_image_size
    y_size = cropped_label_size
    if not skip_augmentations:
        x = tf.keras.layers.RandomFlip(mode="horizontal")(x)
        y = tf.keras.layers.RandomFlip(mode="horizontal")(y)
        x = tf.keras.layers.RandomRotation(factor=1.0, fill_mode='constant')(x)
        y = tf.keras.layers.RandomRotation(factor=1.0, fill_mode='constant')(y)
    # The crop runs on both paths; the crop-only dataset relies on it.
    x = tf.keras.layers.CenterCrop(x_size, x_size)(x)
    y = tf.keras.layers.CenterCrop(y_size, y_size)(y)
    return x, y
# Driver for the reproduction: print the TensorFlow version the question was run on.
print(tf.__version__) # 2.6.0
# Wrap the generator in a dataset; the signature declares one (image, label)
# pair of uint8 tensors with the shapes produced by data_gen.
dataset = tf.data.Dataset.from_generator(data_gen, output_signature=(
tf.TensorSpec(shape=(80, 80, 3), dtype='uint8'),
tf.TensorSpec(shape=(40, 40, 1), dtype='uint8')
))
# Bind the crop sizes up front so each callable only takes (image, label).
crop_only_fn = functools.partial(preprocess, cropped_image_size=50, cropped_label_size=25, skip_augmentations=True)
train_preprocess_fn = functools.partial(preprocess, cropped_image_size=50, cropped_label_size=25, skip_augmentations=False)
# This works
crop_dataset = dataset.map(crop_only_fn)
# This fails: ValueError: Tensor-typed variable initializers must either be wrapped in an init_scope or callable
train_dataset = dataset.map(train_preprocess_fn)
完整堆栈跟踪:
Traceback (most recent call last):
File "./issue_dataaug.py", line 50, in <module>
train_dataset = dataset.map(train_preprocess_fn)
File "/...//virtualenvs/cvi36/lib/python3.6/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 1861, in map
return MapDataset(self, map_func, preserve_cardinality=True)
File "/...//virtualenvs/cvi36/lib/python3.6/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 4985, in __init__
use_legacy_function=use_legacy_function)
File "/...//virtualenvs/cvi36/lib/python3.6/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 4218, in __init__
self._function = fn_factory()
File "/...//virtualenvs/cvi36/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 3151, in get_concrete_function
*args, **kwargs)
File "/...//virtualenvs/cvi36/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 3116, in _get_concrete_function_garbage_collected
graph_function, _ = self._maybe_define_function(args, kwargs)
File "/...//virtualenvs/cvi36/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 3463, in _maybe_define_function
graph_function = self._create_graph_function(args, kwargs)
File "/...//virtualenvs/cvi36/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 3308, in _create_graph_function
capture_by_value=self._capture_by_value),
File "/...//virtualenvs/cvi36/lib/python3.6/site-packages/tensorflow/python/framework/func_graph.py", line 1007, in func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
File "/...//virtualenvs/cvi36/lib/python3.6/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 4195, in wrapped_fn
ret = wrapper_helper(*args)
File "/...//virtualenvs/cvi36/lib/python3.6/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 4125, in wrapper_helper
ret = autograph.tf_convert(self._func, ag_ctx)(*nested_args)
File "/...//virtualenvs/cvi36/lib/python3.6/site-packages/tensorflow/python/autograph/impl/api.py", line 695, in wrapper
raise e.ag_error_metadata.to_exception(e)
ValueError: in user code:
./issue_dataaug.py:25 preprocess *
x = tf.keras.layers.RandomFlip(mode="horizontal")(x)
/...//virtualenvs/cvi36/lib/python3.6/site-packages/keras/layers/preprocessing/image_preprocessing.py:414 __init__ **
self._rng = make_generator(self.seed)
/...//virtualenvs/cvi36/lib/python3.6/site-packages/keras/layers/preprocessing/image_preprocessing.py:1375 make_generator
return tf.random.Generator.from_non_deterministic_state()
/...//virtualenvs/cvi36/lib/python3.6/site-packages/tensorflow/python/ops/stateful_random_ops.py:396 from_non_deterministic_state
return cls(state=state, alg=alg)
/...//virtualenvs/cvi36/lib/python3.6/site-packages/tensorflow/python/ops/stateful_random_ops.py:476 __init__
trainable=False)
/...//virtualenvs/cvi36/lib/python3.6/site-packages/tensorflow/python/ops/stateful_random_ops.py:489 _create_variable
return variables.Variable(*args, **kwargs)
/...//virtualenvs/cvi36/lib/python3.6/site-packages/tensorflow/python/ops/variables.py:268 __call__
return cls._variable_v2_call(*args, **kwargs)
/...//virtualenvs/cvi36/lib/python3.6/site-packages/tensorflow/python/ops/variables.py:262 _variable_v2_call
shape=shape)
/...//virtualenvs/cvi36/lib/python3.6/site-packages/tensorflow/python/ops/variables.py:243 <lambda>
previous_getter = lambda **kws: default_variable_creator_v2(None, **kws)
/...//virtualenvs/cvi36/lib/python3.6/site-packages/tensorflow/python/ops/variable_scope.py:2675 default_variable_creator_v2
shape=shape)
/...//virtualenvs/cvi36/lib/python3.6/site-packages/tensorflow/python/ops/variables.py:270 __call__
return super(VariableMetaclass, cls).__call__(*args, **kwargs)
/...//virtualenvs/cvi36/lib/python3.6/site-packages/tensorflow/python/ops/resource_variable_ops.py:1613 __init__
distribute_strategy=distribute_strategy)
/...//virtualenvs/cvi36/lib/python3.6/site-packages/tensorflow/python/ops/resource_variable_ops.py:1695 _init_from_args
raise ValueError("Tensor-typed variable initializers must either be "
ValueError: Tensor-typed variable initializers must either be wrapped in an init_scope or callable (e.g., `tf.Variable(lambda : tf.truncated_normal([10, 40]))`) when building functions. Please file a feature request if this restriction inconveniences you.
【问题讨论】:
使用上面给定的代码,你能确保问题可以复现吗?您提到的错误是预期的,但运行给定的代码,我没有得到与您提到的相同的错误。 【参考方案1】:我不太确定这是否与您的问题直接相关,但在 TF 2.7 上,您的代码根本不起作用,因为所有 Keras 增强层都需要 float 值而不是 uint8。所以,也许可以尝试像这样转换(cast)你的数据:
import functools
import numpy as np
import tensorflow as tf
def data_gen():
    """Generate ten random uint8 (image, label) pairs.

    Yields:
        ``(x, y)``: an 80x80x3 RGB image and a 40x40x1 downsized mono image,
        both as ``uint8`` NumPy arrays.
    """
    for _ in range(10):
        x = (np.random.random(size=(80, 80, 3)) * 255).astype('uint8')  # rgb image
        y = (np.random.random(size=(40, 40, 1)) * 255).astype('uint8')  # downsized mono image
        yield x, y
def preprocess(image, label, cropped_image_size, cropped_label_size, skip_augmentations=False):
    """Cast to float32, optionally augment, center-crop, and cast back to uint8.

    The inputs are cast to ``float32`` before the Keras preprocessing layers
    are applied, because the answer states that those layers need float
    values rather than ``uint8``.

    Args:
        image: Image tensor for one element of the dataset.
        label: Label tensor for the same element.
        cropped_image_size: Height and width of the cropped image.
        cropped_label_size: Height and width of the cropped label.
        skip_augmentations: When True, only the center crop is applied.

    Returns:
        The processed ``(image, label)`` pair, both cast to ``tf.uint8``.
    """
    x = tf.cast(image, dtype=tf.float32)
    y = tf.cast(label, dtype=tf.float32)
    x_size = cropped_image_size
    y_size = cropped_label_size
    if not skip_augmentations:
        # NOTE(review): the layers are still created inside the mapped
        # function; a commenter reports this variant fails when the dataset
        # is iterated on TF 2.7.
        x = tf.keras.layers.RandomFlip(mode="horizontal")(x)
        y = tf.keras.layers.RandomFlip(mode="horizontal")(y)
        x = tf.keras.layers.RandomRotation(factor=1.0, fill_mode='constant')(x)
        y = tf.keras.layers.RandomRotation(factor=1.0, fill_mode='constant')(y)
    x = tf.keras.layers.CenterCrop(x_size, x_size)(x)
    y = tf.keras.layers.CenterCrop(y_size, y_size)(y)
    return tf.cast(x, dtype=tf.uint8), tf.cast(y, dtype=tf.uint8)
# NOTE(review): the version comment is carried over from the question; the
# surrounding answer discusses TF 2.7.
print(tf.__version__) # 2.6.0
# Wrap the generator in a dataset of (image, label) uint8 pairs.
dataset = tf.data.Dataset.from_generator(data_gen, output_signature=(
tf.TensorSpec(shape=(80, 80, 3), dtype=tf.uint8),
tf.TensorSpec(shape=(40, 40, 1), dtype=tf.uint8)
))
# Bind the crop sizes so each callable only takes (image, label).
crop_only_fn = functools.partial(preprocess, cropped_image_size=50, cropped_label_size=25, skip_augmentations=True)
train_preprocess_fn = functools.partial(preprocess, cropped_image_size=50, cropped_label_size=25, skip_augmentations=False)
# This works
crop_dataset = dataset.map(crop_only_fn)
# NOTE(review): the next comment is copied from the question. A commenter
# reports that on TF 2.7 this map call succeeds and the failure moves to
# iteration (NotFoundError) instead.
# This fails: ValueError: Tensor-typed variable initializers must either be wrapped in an init_scope or callable
train_dataset = dataset.map(train_preprocess_fn)
附带说明,Keras 增强层通常用作您计划训练的模型的一部分。您也可以使用 tf.image 函数,例如 tf.image.central_crop、tf.image.random_flip_left_right,甚至 tfa.image.rotate。
更新 1: 您遇到了评论中提到的错误,因为根据文档(here)的说明,tf.keras.layers.RandomFlip 和 tf.keras.layers.RandomRotation 层仅在训练期间生效。所以请尝试使用其他方法:
import functools
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import tensorflow_addons as tfa
def preprocess(image, label, cropped_image_size, cropped_label_size, skip_augmentations=False):
    """Augment with ``tf.image`` / ``tfa.image`` ops, then center-crop.

    Uses plain image ops instead of the Keras random layers, so nothing
    stateful is constructed inside the function mapped by ``tf.data``.

    Args:
        image: Image tensor for one element of the dataset.
        label: Label tensor for the same element.
        cropped_image_size: Height and width of the cropped image.
        cropped_label_size: Height and width of the cropped label.
        skip_augmentations: When True, only the center crop is applied.

    Returns:
        The processed ``(image, label)`` pair, both cast to ``tf.uint8``.
    """
    x = tf.cast(image, dtype=tf.float32)
    y = tf.cast(label, dtype=tf.float32)
    x_size = cropped_image_size
    y_size = cropped_label_size
    if not skip_augmentations:
        # Two independent random flips could flip the image but not its
        # label. Drawing one seed and using the stateless op for both makes
        # the same flip decision for the pair.
        flip_seed = tf.random.uniform([2], maxval=tf.int32.max, dtype=tf.int32)
        x = tf.image.stateless_random_flip_left_right(x, seed=flip_seed)
        y = tf.image.stateless_random_flip_left_right(y, seed=flip_seed)
        # tfa.image.rotate expects radians; the original passed 90, which is
        # 90 radians. NOTE(review): assumes a 90-degree turn was intended.
        quarter_turn = np.pi / 2
        x = tfa.image.rotate(x, quarter_turn, fill_mode='constant')
        y = tfa.image.rotate(y, quarter_turn, fill_mode='constant')
    x = tf.keras.layers.CenterCrop(x_size, x_size)(x)
    y = tf.keras.layers.CenterCrop(y_size, y_size)(y)
    return tf.cast(x, dtype=tf.uint8), tf.cast(y, dtype=tf.uint8)
# NOTE(review): data_gen is not defined in this snippet; presumably it is
# reused from the earlier snippet.
dataset = tf.data.Dataset.from_generator(data_gen, output_signature=(
tf.TensorSpec(shape=(80, 80, 3), dtype=tf.uint8),
tf.TensorSpec(shape=(40, 40, 1), dtype=tf.uint8)
))
# Bind the crop sizes so each callable only takes (image, label).
crop_only_fn = functools.partial(preprocess, cropped_image_size=50, cropped_label_size=25, skip_augmentations=True)
train_preprocess_fn = functools.partial(preprocess, cropped_image_size=50, cropped_label_size=25, skip_augmentations=False)
crop_dataset = dataset.map(crop_only_fn)
train_dataset = dataset.map(train_preprocess_fn)
# Pull a single augmented element and display its image to check the pipeline runs.
image, _ = next(iter(train_dataset.take(1)))
plt.imshow(image.numpy())
我排除了 tf.keras.preprocessing.image.random_rotation,因为它目前似乎不适用于张量。
【讨论】:
我尝试使用 TF 2.7(将数据转换为 float32 后),映射语句可以无异常执行。但是,问题似乎转移到了另一个地方:为两个数据集分别创建迭代器并用 next() 取一个样本,对 crop_dataset 有效,但对 train_dataset 无效,报错为 errors_impl.NotFoundError:NOT FOUND node random_flip_1/stateful_uniform_full_int/RngReadAndSkip。【参考方案2】:正如我在评论中所说,我无法复现您提到的错误。不过,解决办法只需要在 __init__ 方法中初始化增强层。
ValueError:在构建函数时,张量类型的变量初始化器必须包装在 init_scope 中,或者是可调用对象(例如 `tf.Variable(lambda : tf.truncated_normal([10, 40]))`)。如果此限制给您带来不便,请提交功能请求。
这是完整的工作代码。
def data_gen():
    """Produce ten synthetic (image, label) samples.

    Yields:
        A pair of ``uint8`` NumPy arrays: an 80x80x3 RGB image and a
        40x40x1 downsized mono image.
    """
    for _ in range(10):
        x = np.random.random(size=(80, 80, 3)) * 255  # rgb image
        y = np.random.random(size=(40, 40, 1)) * 255  # downsized mono image
        yield x.astype('uint8'), y.astype('uint8')
class Augment(tf.keras.layers.Layer):
    """Random flip and rotation applied to an (image, label) pair.

    The random preprocessing layers are created once here in ``__init__``
    rather than inside the function passed to ``Dataset.map``; the answer
    presents this as the fix for the error in the question.
    """

    def __init__(self, seed=42):
        """Create the flip and rotation layers.

        Args:
            seed: Seed passed to every random layer. The image and label
                layers get the same value, presumably so both receive the
                same random transform (TODO confirm).
        """
        super().__init__()
        self.flip_a = tf.keras.layers.RandomFlip(mode="horizontal", seed=seed)
        self.flip_b = tf.keras.layers.RandomFlip(mode="horizontal", seed=seed)
        self.rot_a = tf.keras.layers.RandomRotation(factor=1.0,
                                                    fill_mode='constant', seed=seed)
        self.rot_b = tf.keras.layers.RandomRotation(factor=1.0,
                                                    fill_mode='constant', seed=seed)

    def call(self, inputs, labels):
        """Flip and then rotate the image and the label.

        Args:
            inputs: Image tensor.
            labels: Label tensor.

        Returns:
            The augmented ``(image, label)`` pair.
        """
        x = self.flip_a(inputs)
        x = self.rot_a(x)
        y = self.flip_b(labels)
        y = self.rot_b(y)
        return x, y
def preprocess(image, label, cropped_image_size, cropped_label_size):
    """Center-crop an (image, label) pair without any random augmentation.

    The inputs are cast to ``float32`` for the crop layers and cast back to
    ``uint8`` afterwards.

    Args:
        image: Image tensor for one element of the dataset.
        label: Label tensor for the same element.
        cropped_image_size: Height and width of the cropped image.
        cropped_label_size: Height and width of the cropped label.

    Returns:
        The cropped ``(image, label)`` pair as ``tf.uint8`` tensors.
    """
    x = image
    y = label
    x_size = cropped_image_size
    y_size = cropped_label_size
    x = tf.cast(x, dtype=tf.float32)
    y = tf.cast(y, dtype=tf.float32)
    x = tf.keras.layers.CenterCrop(x_size, x_size)(x)
    y = tf.keras.layers.CenterCrop(y_size, y_size)(y)
    x = tf.cast(x, dtype=tf.uint8)
    y = tf.cast(y, dtype=tf.uint8)
    return x, y
数据
# Wrap the generator in a dataset; the signature declares one (image, label)
# pair of uint8 tensors with the shapes produced by data_gen.
dataset = tf.data.Dataset.from_generator(data_gen, output_signature=(
tf.TensorSpec(shape=(80, 80, 3), dtype='uint8'),
tf.TensorSpec(shape=(40, 40, 1), dtype='uint8')
))
测试 1
# Test 1: crop only. Bind the crop sizes so the callable takes (image, label).
crop_only_fn = functools.partial(preprocess,
cropped_image_size=50,
cropped_label_size=25)
# This works
crop_dataset = dataset.map(crop_only_fn)
# Fetch one element and inspect its shapes (notebook-style expression; the
# output is shown on the following line of the page).
x, y = next(iter(crop_dataset))
x.shape, y.shape
(TensorShape([50, 50, 3]), TensorShape([25, 25, 1]))
测试 2
# Test 2: crop first, then augment in a second map stage.
train_preprocess_fn = functools.partial(preprocess,
cropped_image_size=50,
cropped_label_size=25)
train_dataset = dataset.map(train_preprocess_fn)
# Augment() is constructed here, once, and the resulting layer instance is
# what gets passed to map.
train_dataset = train_dataset.map(Augment()) # < calling now.
# Fetch one element and inspect its shapes (notebook-style expression).
x, y = next(iter(train_dataset))
x.shape, y.shape
(TensorShape([50, 50, 3]), TensorShape([25, 25, 1]))
【讨论】:
以上是关于Keras 扩充不适用于 tf.data.Dataset 映射的主要内容,如果未能解决你的问题,请参考以下文章
Keras 函数(K.function)不适用于 RNN(提供的代码)
TFLite 转换器:为 keras 模型实现的 RandomStandardNormal,但不适用于纯 TensorFlow 模型
使用 dropout (TF2.0) 时,可变批量大小不适用于 tf.keras.layers.RNN?
如何在 tf.data.Dataset 中输入不同大小的列表列表