如何更改 tf.data.Dataset 中数据的 dtype?
Posted
技术标签:
【中文标题】如何更改 tf.data.Dataset 中数据的 dtype?【英文标题】:How to change the dtype of data in tf.data.Dataset? 【发布时间】:2021-08-24 22:59:16 【问题描述】:我有一个使用此 API 从目录加载的数据集
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
data_dir,
validation_split=0.3,
subset="validation",
seed=123,
image_size=(img_height, img_width),
batch_size=batch_size)
我想更改数据类型并加快训练速度
我试过了,但是没用
for image_batch, labels_batch in train_ds:
image_batch = tf.cast(image_batch,tf.int16)
【问题讨论】:
您不应将x
类型转换为int
类型(可能会遇到数值不稳定),而应考虑使用混合精度 技术来加快训练速度。
我该怎么做?
【参考方案1】:
只需为您的数据集应用map 方法:
val_ds.map(lambda x, y: (tf.cast(x, tf.int16), y))
【讨论】:
以上是关于如何更改 tf.data.Dataset 中数据的 dtype?的主要内容,如果未能解决你的问题,请参考以下文章
如何将 tf.data.Dataset 与 kedro 一起使用?
如何在 keras 自定义回调中访问 tf.data.Dataset?
Tensorflow:如何查找 tf.data.Dataset API 对象的大小