规范化 tf.data.Dataset

Posted

技术标签:

【中文标题】规范化 tf.data.Dataset【英文标题】:Normalize tf.data.Dataset 【发布时间】:2021-10-05 23:29:50 【问题描述】:

我有一个tf.data.Dataset 的图像,其输入形状(批量大小,128、128、2)和目标形状(批量大小,128、128、1),其中输入是 2 通道图像(复杂-具有代表实部和虚部的两个通道的值图像)和目标是1通道图像(实值图像)。 我需要通过首先从它们中删除它们的平均图像然后将它们缩放到(0,1)范围来规范化输入和目标图像。如果我没记错的话,tf.data.Dataset 一次只能处理一个批次,而不是整个数据集。所以我从remove_meanpy_function中的批次中的每个图像中删除批次的平均图像,然后通过减去它的最小值并除以它的最大值和最小值之差来将每个图像缩放到(0,1) py_function linear_scaling 中的值。但是在应用函数之前和之后从数据集中打印输入图像中的最小值和最大值之后,图像值没有变化。 谁能建议这可能出了什么问题?

def remove_mean(image, target):
    image_mean = np.mean(image, axis=0)
    target_mean = np.mean(target, axis=0)
    image = image - image_mean
    target = target - target_mean
    return image, target

def linear_scaling(image, target):
    image_min = np.ndarray.min(image, axis=(1,2), keepdims=True)
    image_max = np.ndarray.max(image, axis=(1,2), keepdims=True)
    image = (image-image_min)/(image_max-image_min)

    target_min = np.ndarray.min(target, axis=(1,2), keepdims=True)
    target_max = np.ndarray.max(target, axis=(1,2), keepdims=True)
    target = (target-target_min)/(target_max-target_min)
    return image, target

a, b = next(iter(train_dataset))
print(tf.math.reduce_min(a[0,:,:,:]))

train_dataset.map(lambda item1, item2: tuple(tf.py_function(remove_mean, [item1, item2], [tf.float32, tf.float32])))
test_dataset.map(lambda item1, item2: tuple(tf.py_function(remove_mean, [item1, item2], [tf.float32, tf.float32])))

a, b = next(iter(train_dataset))
print(tf.math.reduce_min(a[0,:,:,:]))

train_dataset.map(lambda item1, item2: tuple(tf.py_function(linear_scaling, [item1, item2], [tf.float32])))
test_dataset.map(lambda item1, item2: tuple(tf.py_function(linear_scaling, [item1, item2], [tf.float32])))

a, b = next(iter(train_dataset))
print(tf.math.reduce_min(a[0,:,:,:]))


Output -

tf.Tensor(-0.00040511801, shape=(), dtype=float32)
tf.Tensor(-0.00040511801, shape=(), dtype=float32)
tf.Tensor(-0.00040511801, shape=(), dtype=float32)

【问题讨论】:

【参考方案1】:

map 不是就地操作,所以当您执行train_dataset.map(....) 时,您的train_dataset 不会改变。

train_dataset = train_dataset.map(...)

【讨论】:

以上是关于规范化 tf.data.Dataset的主要内容,如果未能解决你的问题,请参考以下文章

002.tf.data.DataSet

建议在 tensorflow 2.0 中调试 `tf.data.Dataset` 操作

如何在 tf.data.Dataset.map 中使用 sklearn.preprocessing?

如何在 tf.data.Dataset 中输入不同大小的列表列表

002.tf.data.DataSet

002.tf.data.DataSet