如何通过 tf.data API 使用 Keras 生成器
Posted
技术标签:
【中文标题】如何通过 tf.data API 使用 Keras 生成器【英文标题】:How to use Keras generator with tf.data API 【发布时间】:2019-03-09 05:17:36 【问题描述】:我正在尝试使用 Keras 预处理库中的生成器。我想对此进行试验,因为 Keras 为图像增强提供了强大的功能。但是,我不确定这是否真的可行。
这是我从 Keras 生成器制作 tf 数据集的方法:
def make_generator():
train_datagen = ImageDataGenerator(rescale=1. / 255)
train_generator =
train_datagen.flow_from_directory(train_dataset_folder,target_size=(224, 224), class_mode='categorical', batch_size=32)
return train_generator
train_dataset = tf.data.Dataset.from_generator(make_generator,(tf.float32, tf.float32)).shuffle(64).repeat().batch(32)
请注意,如果您尝试直接将train_generator
作为tf.data.Dataset.from_generator
的参数,则会出现错误。但是,上述方法不会产生错误。
当我在会话中运行它以检查数据集的输出时,我收到以下错误。
iterator = train_dataset.make_one_shot_iterator()
next_element = iterator.get_next()
sess = tf.Session()
for i in range(100):
sess.run(next_element)
找到属于 2 个类别的 1000 张图像。 -------------------------------------------------- ------------------------- InvalidArgumentError Traceback(最近调用 最后的) /usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py 在 _do_call(self, fn, *args) 1291 中尝试: -> 1292 return fn(*args) 1293 除了errors.OpError as e:
/usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py 在 _run_fn(feed_dict、fetch_list、target_list、options、run_metadata) 第1276章 -> 1277 个选项,feed_dict,fetch_list,target_list,run_metadata)1278
/usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py 在 _call_tf_sessionrun(self, options, feed_dict, fetch_list, target_list,run_metadata)1366 self._session,选项, feed_dict, fetch_list, target_list, -> 1367 运行元数据)1368
InvalidArgumentError: 不能批量处理不同形状的张量 组件 0。第一个元素的形状为 [32,224,224,3],元素 29 的形状为 形状 [8,224,224,3]。 [[node IteratorGetNext_2 = IteratorGetNextoutput_shapes=[, ], output_types=[DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"]]
在处理上述异常的过程中,又发生了一个异常:
如果有人对此有任何经验或知道任何替代方法,请告诉我。
更新
在使用了 J.E.K. 的建议后,我能够解决问题
train_dataset = tf.data.Dataset.from_generator(make_generator,(tf.float32, tf.float32))
但是,当我将 train_dataset
提供给 Keras .fit
方法时,我收到以下错误。
model_regular.fit(train_dataset,steps_per_epoch=1000,epochs=2)
----------------------------------- ---------------------------- ValueError Traceback(最近一次调用 最后)在() ----> 1 model_regular.fit(train_dataset,steps_per_epoch=1000,epochs=2)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split,validation_data,shuffle,class_weight, sample_weight,initial_epoch,steps_per_epoch,validation_steps, **kwargs) 1507 steps_name='steps_per_epoch', 1508 steps=steps_per_epoch, -> 1509 validation_split=validation_split) 1510 1511 # 准备验证数据。
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py 在 _standardize_user_data(self, x, y, sample_weight, class_weight, batch_size、check_steps、steps_name、steps、validation_split) 第948章 949 其他: --> 950 迭代器 = x.make_initializable_iterator() 第951章 952 x = 迭代器
/usr/local/lib/python3.6/dist-packages/tensorflow/python/data/ops/dataset_ops.py 在 make_initializable_iterator(self, shared_name) 119 与 ops.colocate_with(iterator_resource): 120 初始化器 = gen_dataset_ops.make_iterator(self._as_variant_tensor(), --> 121 迭代器资源) 122 return iterator_ops.Iterator(iterator_resource, 初始化器, 123 self.output_types,self.output_shapes,
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/gen_dataset_ops.py 在 make_iterator(dataset, iterator, name) 2542 如果 _ctx 为 None 或 不是 _ctx._eager_context.is_eager: 2543 _, _, _op = _op_def_lib._apply_op_helper( -> 2544 "MakeIterator", dataset=dataset, iterator=iterator, name=name) 2545 return _op 2546 _result = None
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/op_def_library.py 在 _apply_op_helper(self, op_type_name, name, **keywords) 348 # 需要将所有参数扁平化成一个列表。 第349章 --> 350 g = ops._get_graph_from_inputs(_Flatten(keywords.values())) 第351章 352 除了 AssertionError as e:
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py 在 _get_graph_from_inputs(op_input_list, graph) 5659 图 = graph_element.graph 5660 elif original_graph_element 不是无: -> 5661 _assert_same_graph(original_graph_element, graph_element) 5662 elif graph_element.graph 不是图: 5663 raise ValueError("%s 不是来自传入的图表。" % 图元素)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py 在 _assert_same_graph(original_item, item) 5595 如果 original_item.graph 不是 item.graph: 5596 raise ValueError("%s 必须与 %s 来自同一个图表。" % (item, -> 5597 original_item)) 5598 5599
ValueError: Tensor("IteratorV2:0", shape=(), dtype=resource) 必须是 来自与 Tensor("FlatMapDataset:0", shape=(), dtype=variant)。
这是一个错误还是 Keras fit 方法不应该以这种方式使用?
【问题讨论】:
How to Properly Combine TensorFlow's Dataset API and Keras?的可能重复 关于为什么你需要传递make_generator
而不是train_generator
,docs explain it:“构造函数将可调用作为输入,而不是迭代器。这允许它在生成时重新启动生成器到达终点。它接受一个可选的 args 参数,作为可调用的参数传递。"
【参考方案1】:
我尝试通过一个简单的示例重现您的结果,我发现当在生成器函数和tf.data
中使用批处理时,您会得到不同的输出形状。
Keras 函数train_datagen.flow_from_directory(batch_size=32)
已经返回形状为[batch_size, width, height, depth]
的数据。如果使用tf.data.Dataset().batch(32)
,则输出数据将再次批处理为[batch_size, batch_size, width, height, depth]
。
这可能由于某种原因导致了您的问题。
【讨论】:
感谢您的建议。它有效,我能够输出输出。但是,当我将 train_dataset 提供给 model.fit() 时,它会给出另一个错误。我将使用新错误更新原始帖子。【参考方案2】:不应该
model_regular.fit(train_dataset,steps_per_epoch=1000,epochs=2)
是
model_regular.fit(train_dataset.make_one_shot_iterator(),steps_per_epoch=1000,epochs=2)
根据this answer。
【讨论】:
以上是关于如何通过 tf.data API 使用 Keras 生成器的主要内容,如果未能解决你的问题,请参考以下文章
如何将 Keras tf.data 与生成器( flow_from_dataframe )一起使用?形成完美的输入管道
如何在 keras 自定义回调中访问 tf.data.Dataset?
如何在Tensorflow中组合feature_columns,model_to_estimator和dataset API
如何更改 tf.data.Dataset 中数据的 dtype?