如何强制 TensorFlow 在 float16 下运行?

Posted

技术标签:

【中文标题】如何强制 TensorFlow 在 float16 下运行?【英文标题】:How to Force Tensorflow to Run under float16? 【发布时间】:2019-06-16 03:29:20 【问题描述】:

我正在通过定义一个由 keras 的 tf 后端和一些 tf 的张量运算符自己编写的新类来构建一个带有自定义激活函数的 Keras 序列模型。我把自定义激活函数放在../keras/advanced_activation.py.

我打算使用 float16 精度运行它。如果没有自定义函数,我可以使用以下方法轻松地在 float32 和 float16 之间进行选择:

        if self.precision == 'float16':
            K.set_floatx('float16')
            K.set_epsilon(1e-4)
        else:
            K.set_floatx('float32')
            K.set_epsilon(1e-7)

然而,当我的模型中包含自定义函数时,即使我选择了 float16,tf 似乎仍然存在于 float32 中。我知道 tf 默认在 flat32 下运行,所以我的问题是:

    在同一个文件中还有几个内置的激活函数,Keras 是如何让它们在 float16 下运行的,以便我可以做同样的事情?有一个 tf 方法 tf.dtypes.cast(...),我可以在我的自定义函数中使用它来强制 tf 吗?这些内置函数中没有这样的演员表。

    另外,如何通过使用 Keras 和 tf 作为后端来强制 tf 在 float16 下直接运行?

非常感谢。

【问题讨论】:

作为一种肮脏的解决方法(或进一步调试它的方法),我可能会在每次应用自定义函数时推荐 tf.dtypes.cast() 。如果没有,如果您更详细地描述您的自定义函数,它可能会很有用。是纯 TF 还是涉及调用 C++? 纯粹是用 tf.例如,一个段看起来像inputss = tf.where(tf.math.logical_and(tf.greater(orig, 0), tf.less(orig, 0.25)), 0.25 / (1+tf.exp(-self.sharp*((inputss-0.125)/0.25))), inputss) 我想我会尝试强制转换,但是你知道我可以通过哪种方式判断是否在我的函数中应用了强制转换,即我可以使用哪个 tf 变量来调用是否应用强制转换的条件操作? 使用 tf.constant 将这些常量(如 0.125 和 0.25)包装在所需的 dtype 中。 TF 可能会向上转换为 float32,因为这些默认 dtype 【参考方案1】:

我通过调试得到了答案。教训是

首先,tf.dtypes.cast(...) 有效。

其次,我可以在我的自定义激活函数中指定第二个参数来指示 cast(...) 的数据类型。以下是相关代码

第三,我们不需要 tf.constant 来表示那些常量的数据类型

第四,我的结论是,在 custom_activation.py 中添加自定义函数是定义我们自己的层/激活的最简单方法,只要它在任何地方都是可微的,或者至少是分段可微的并且在接合处没有不连续性。

# Quadruple Piece-Wise Constant Function
class MyFunc(Layer):


    def __init__(self, sharp=100, DataType = 'float32', **kwargs):
        super(MyFunc, self).__init__(**kwargs)
        self.supports_masking = True
        self.sharp = K.cast_to_floatx(sharp)
        self.DataType = DataType

    def call(self, inputs):

        inputss = tf.dtypes.cast(inputs, dtype=self.DataType)
        orig = inputss
        # some calculations
        return # my_results

    def get_config(self):
        config = 'sharp': float(self.sharp), 
                  'DataType': self.DataType
        base_config = super(MyFunc, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def compute_output_shape(self, input_shape):
        return input_shape

感谢 @y.selivonchyk 与我进行有价值的讨论,感谢 @Yolo Swaggins 的贡献。

【讨论】:

以上是关于如何强制 TensorFlow 在 float16 下运行?的主要内容,如果未能解决你的问题,请参考以下文章

Tensorflow中float32模型强制转为float16半浮点模型

用于 VGG19 模型参数的 Tensorflow Float16

是否可以在不强制转换的情况下初始化 float32 或 float16 的随机数组?

如何在 pytorch 和 tensorflow 中使用张量核心?

03 使用Tensorflow做计算题

tensorflow fp16训练