如何更改 tf.data.Dataset 中数据的 dtype?

Posted

技术标签:

【中文标题】如何更改 tf.data.Dataset 中数据的 dtype?【英文标题】:How to change the dtype of data in tf.data.Dataset? 【发布时间】:2021-08-24 22:59:16 【问题描述】:

我有一个使用此 API 从目录加载的数据集

val_ds = tf.keras.preprocessing.image_dataset_from_directory(
  data_dir,
  validation_split=0.3,
  subset="validation",
  seed=123,
  image_size=(img_height, img_width),
  batch_size=batch_size)

我想更改数据类型并加快训练速度

我试过了,但是没用

for image_batch, labels_batch in train_ds:
  image_batch = tf.cast(image_batch,tf.int16)

【问题讨论】:

您不应将x 类型转换为int 类型(可能会遇到数值不稳定),而应考虑使用混合精度 技术来加快训练速度。 我该怎么做? 【参考方案1】:

只需为您的数据集应用map 方法:

val_ds.map(lambda x, y: (tf.cast(x, tf.int16), y))

【讨论】:

以上是关于如何更改 tf.data.Dataset 中数据的 dtype?的主要内容,如果未能解决你的问题,请参考以下文章

如何将 tf.data.Dataset 与 kedro 一起使用?

如何在 keras 自定义回调中访问 tf.data.Dataset?

Tensorflow:如何查找 tf.data.Dataset API 对象的大小

如何按特定值过滤 tf.data.Dataset?

如何在 tensorflow tf.data.Dataset 中使用 cv2 图像增强功能?

如何使用 tf.data.Dataset 对象的 map 方法删除或省略数据?