Tensorflow1 和 Tensorflow2 中的批处理

Posted

技术标签:

【中文标题】Tensorflow1 和 Tensorflow2 中的批处理【英文标题】:Batching in Tensorflow1 and Tensorflow2 【发布时间】:2021-07-15 14:56:59 【问题描述】:

我正在尝试将图像单应性代码从 TF1 版本转换为 TF2,只是 TF 脚本转换在这里不起作用。我坚持对数据集进行批处理,因为图像、image_patch 和 image_Indices 具有不同的形状。虽然 TF1 在摄取和批处理数据集包方面没有问题,但 TF2 却遇到了麻烦。

imgs= np.random.rand(11,240,320,3)
pts = np.random.randint(100, size =(11,8))
patch = np.random.rand(11,128,128,1)

imgs = tf.convert_to_tensor(imgs)
pts = tf.convert_to_tensor(pts)
patch = tf.convert_to_tensor(patch)

pts= tf.cast(pts,dtype=tf.float64)

张量流2:

    img_batch,pts_batch,patch_batch = tf.data.Dataset.from_tensor_slices([imgs,pts,patch]).shuffle(buffer_size=batch_size*4)

这里 11 是图像数量,240 和 320 是图像尺寸,3 是通道数。

错误-

tensorflow.python.framework.errors_impl.InvalidArgumentError: Shapes of all inputs must match: values[0].shape = [11,240,320,3] != values[2].shape = [11,128,128,1] [Op:Pack] name: component_0

张量流1:

tf.compat.v1.train.batch([imgs,pts,patch], batch_size=5)

输出 -

[<tf.Tensor 'batch_2:0' shape=(5, 11, 240, 320, 3) dtype=float64>,
 <tf.Tensor 'batch_2:1' shape=(5, 11, 8) dtype=float64>,
 <tf.Tensor 'batch_2:2' shape=(5, 11, 128, 128, 1) dtype=float64>]

如何在tensorflow2中批量处理不同维度的数据集? 同样在运行时,“tf.compat.v1.train.batch()”在 TF2(tensoflow 版本 2.3)中不起作用,因为它给出了急切的执行错误。

在 TF2 中批处理此类数据集的正确方法是什么?

【问题讨论】:

【参考方案1】:

这里的问题不是批处理,而是tf.data.Dataset 本身的生成。错误是由img_batch,pts_batch,patch_batch = tf.data.Dataset.from_tensor_slices([imgs,pts,patch]) 引起的,不是由.shuffle(batch_size=...) 引起的。

我觉得.from_tensor_slices这里级别太高了,看看tf.data.Dataset.from_generator

【讨论】:

以上是关于Tensorflow1 和 Tensorflow2 中的批处理的主要内容,如果未能解决你的问题,请参考以下文章

Python Tensorflow1.x升级到2.x低阶API实践

Python Tensorflow1.x升级到2.x低阶API实践

Python Tensorflow1.x升级到2.x低阶API实践

tensorflow2.0 安装教程

GitHub标星2000+,如何用30天啃完TensorFlow2.0?

Tensorflow | 绪论