如何将 tf.keras 与 bfloat16 一起使用

Posted

技术标签:

【中文标题】如何将 tf.keras 与 bfloat16 一起使用【英文标题】:How to use tf.keras with bfloat16 【发布时间】:2019-09-30 16:00:11 【问题描述】:

我正在尝试使用混合精度让 tf.keras 模型在 TPU 上运行。我想知道如何使用 bfloat16 混合精度构建 keras 模型。是这样的吗?

with tf.contrib.tpu.bfloat16_scope():
    inputs = tf.keras.layers.Input(shape=(2,), dtype=tf.bfloat16)
    logits = tf.keras.layers.Dense(2)(inputs)

logits = tf.cast(logits, tf.float32)
model = tf.keras.models.Model(inputs=inputs, outputs=logits)
model.compile(optimizer=tf.keras.optimizers.Adam(.001),
              loss='mean_absolute_error', metrics=[])

tpu_model = tf.contrib.tpu.keras_to_tpu_model(
        model,
        strategy=tf.contrib.tpu.TPUDistributionStrategy(
            tf.contrib.cluster_resolver.TPUClusterResolver(tpu='my_tpu_name')
        )
    )

【问题讨论】:

cloud.google.com/tpu/docs/bfloat16 你能不能请这个.. 该链接未指定如何使用 tf.keras。所有示例都是针对普通张量流的。 你可以用 google colab 试试看。 github.com/tensorflow/tensorflow/issues/26759,截至目前 tf.keras 不支持 bfloat16。 好像说不支持保存hdf5格式的模型。似乎它仍然可以训练模型并以 TF SavedModel 格式保存。 @TensorflowSupport 您收到该错误是因为我为 TPU 输入了一个假名称。您需要在此处输入自己的 URL。 【参考方案1】:

您可以使用下面显示的代码使用bfloat16 Mixed Precisionfloat16 计算和float32 变量)构建Keras 模型。

tf.keras.mixed_precision.experimental.set_policy('infer_float32_vars')

model = tf.keras.Sequential([
    tf.keras.layers.Inputs(input_shape=(2, ), dtype=tf.float16),    
    tf.keras.layers.Lambda(lambda x: tf.cast(x, 'float32')),
    tf.keras.layers.Dense(10)])

model.compile(optimizer=tf.keras.optimizers.Adam(.001),
              loss='mean_absolute_error', metrics=[])

model.fit(.............)

构建并训练模型后,我们可以使用以下步骤保存模型:

tf.keras.experimental.export_saved_model(model, path_to_save_model)

我们可以使用以下代码加载保存的混合精度 Keras 模型:

new_model = tf.keras.experimental.load_from_saved_model(path_to_save_model)
new_model.summary()

【讨论】:

以上是关于如何将 tf.keras 与 bfloat16 一起使用的主要内容,如果未能解决你的问题,请参考以下文章

tf.keras 模型 多个输入 tf.data.Dataset

TensorFlow2 入门指南 | 11 Keras 与 tf.keras 总体框架介绍

使用 tf.keras.Model.fit 进行训练时如何将自定义摘要添加到 tensorboard

了解tf.keras中模型训练验证的相关方法

无法将 tf.keras 模型正确转换为珊瑚 TPU 的量化格式

神经网络图片基础与tf.keras介绍