训练后使用量化权重进行 keras 模型评估
Posted
技术标签:
【中文标题】训练后使用量化权重进行 keras 模型评估【英文标题】:keras model evaluation with quantized weights post training 【发布时间】:2019-03-20 00:27:24 【问题描述】:我有一个在 keras 中训练的模型,并保存为 .h5 文件。该模型使用带有 tensorflow 后端的单精度浮点值进行训练。现在我想实现一个硬件加速器,它在 Xilinx FPGA 上执行卷积操作。但是,在决定要在 FPGA 上使用的定点位宽之前,我需要通过将权重量化为 8 位或 16 位数字来评估模型的准确性。我遇到了tensorflow quantise,但我不确定如何从每一层获取权重,量化它并将其存储在一个 numpy 数组列表中。在所有层都被量化之后,我想将模型的权重设置为新形成的量化权重。有人可以帮我做这个吗?
这是我迄今为止尝试将精度从 float32 降低到 float16 的方法。请让我知道这是否是正确的方法。
for i in range(len(w_orginal)):
temp_shape = w_orginal[i].shape
print('Shape of index: '+ str(i)+ 'array is :')
print(temp_shape)
temp_array = w_orginal[i]
temp_array_flat = w_orginal[i].flatten()
for j in range(len(temp_array)):
temp_array_flat[j] = temp_array_flat[j].astype(np.float16)
temp_array_flat = temp_array_flat.reshape(temp_shape)
w_fp_16_test.append(temp_array_flat)
【问题讨论】:
【参考方案1】:对不起,我对 tensorflow 不熟悉,所以我不能给你代码,但也许我在量化 caffe 模型方面的经验可能有意义。
如果我理解正确,您有一个 tensorflow 模型 (float32),您希望将其量化为 int8 并将其保存在 numpy.array
中。
首先,你应该读取每一层的所有权重,可能是 python 列表或numpy.array
或其他,没关系。
然后,量化算法会显着影响准确性,您必须为您的模型选择最佳的算法。然而,这些算法有一个共同的核心——规模。您需要做的就是将所有权重缩放到 -127 到 127(int8),就像没有 bias
的 scale
层一样,并记录比例因子。
也就是说,如果要在FPGA上实现,数据也要量化。这里又出现了一个新问题——int8 * int8的结果是一个int16,很明显是溢出了。
为了解决这个问题,我们创建了一个新参数 -- shift -- 将 int16 结果移回 int8。注意,shift
参数不会是常数 8,假设我们有 0 * 0 = 0,我们根本不需要移动结果。
最后一个我们要考虑的问题是,如果网络太深,层结果可能会溢出,因为一些不合理的scale
参数,所以我们不能直接量化每一层而不考虑其他层。
在FPGA上完成所有网络之后,如果您想将int8反量化为float32,只需使用最后一个比例参数(最终结果)进行一些mul/div(取决于您如何定义scale
)。
这是一个基本的量化算法,其他像tf.quantization
可能有更高的精度。现在我们有了量化模型,你可以把它保存成你喜欢的任何东西,这并不难。
附:为什么是麻木的? bin 文件最适合 FPGA,不是吗?
而且,您对在 FPGA 上实现 softmax 有什么想法吗?我对此感到困惑......
【讨论】:
是的...这就是我正在尝试的...但是很抱歉,我不明白如何将 float32 数字缩放到 int8 范围内的 -127 到127... 这让我有些困惑。接下来是溢出的处理...如果结果大于127或小于-127...我只是将其设置为可能的最大值...这就是我正在尝试的...但是keras内部采用了这些仅作为 float 32 的值 量化就像一个倒退的过程,你可以尝试将结果缩放到int8,然后你就会知道缩放后的输入数据的范围,而输入的数据就是上一层的结果…… 对不起,我不知道如何从 keras 获取 float32 数字……但我认为这是 keras 应该具备的基本功能。尝试在用户指南中找到它?以上是关于训练后使用量化权重进行 keras 模型评估的主要内容,如果未能解决你的问题,请参考以下文章
:模型训练和预测的三种方法(fit&tf.GradientTape&train_step&tf.data)