Tensor Flow PB文件量化到TFLITE

Posted 17岁boy想当攻城狮

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Tensor Flow PB文件量化到TFLITE相关的知识,希望对你有一定的参考价值。

前言

一般在Slim上进行完迁移训练之后我们想将它量化到TFLITE需要先将CKPT量化到PB,在将PB量化到TFLITE,这个原因是因为格式的原因,CKPT是使用多个文件存储模型不同的权重、变量、算子,网格结构,而PB仅一个文件存储所有信息,在tf提供了tf.lite.TFLiteConverter来量化到TFLITE,而这个量化工具只支持PB,所以一般需要将CKPT转化为PB在转TFLITE。

实现代码

首先包含相应头模块

import tensorflow as tf
import io
import PIL
import numpy as np

说一下TFLiteConverter的依赖,TFLiteConverter依赖三个参数:

inference_input_type:输出类型
inference_output_type:输入类型
optimizations:优化器
representative_dataset:Tensor Flow量化里representative_dataset参数是什么意思?_17岁boy的博客-CSDN博客

首先我们为rep编写用于校准TFLITE变量的验证集数据,注意它需要yield函数,所以我们先写一个python函数:

def rep():

我们的验证集是用tfrecord格式存取的,内部其实就是protobuf,我们用tf的protobuf模块来解码:

 #需要是验证集的数据源
record_iterator = tf.python_io.tf_record_iterator(path='/home/xxxx/models/research/slim/ci_data/cifar10_train.tfrecord')

然后定义一个count,我们要循环取验证集数据,所以用这个变量来做索引:

count = 0

然后for循环遍历

for string_record in record_iterator:

这里用Example来解析Protobuf

example = tf.train.Example()
example.ParseFromString(string_record)

解码之后直接就可以按数据流取出来:

image_stream = io.BytesIO(example.features.feature['image/encoded'].bytes_list.value[0])

然后我们用PIL以数据流的形式打开IMAGE

image = PIL.Image.open(image_stream)

接下来我们就要开始量化了,这里的量化图像size就为了更好的量化TFLITE里面浮点数数组大小,控制size,然后转化为灰度图,即通道数为1

image = image.resize((96,96))
image = image.convert('L')

然后我们对其增加维度,因为使用了resize所以维度大小现在是:96x96,二维数组,而TFLITE量化时要求4维,并且顺序是这样的:(1,96,96,1)

第一维用于存储batch-size,后面两维用于存储图像像素,最后一维是通道数,用于表示多少字节为一个像素点,使用expand_dims来扩维,axis=2即在第2维上扩充一维出来,这个维度的大小元素为1

0就是在第1维的前面扩充一维,在扩维时python以1为起始坐标,不同于C语言的0。

array = np.array(image)
array = np.expand_dims(array,axis=2)
array = np.expand_dims(array,axis=0)

当然你也可以用reshape来扩维,效果是一样的,扩充一维

array.reshape((array.shape[0], array.shape[1], array.shape[2], 1))

扩维后的数据我们不需要给值,这个值TFLITE会从PB中获取并写入到里面去,然后自动量化,我们提供的仅仅是一个四维的数组,其中两维是图像数据,如:

(10,96,96,1)

表示的是每次输入10个96x96的数组进来,数组里每个1个字节代表一个像素点,通常这个值会给卷积核用,卷积核中有个参数就是通道size,一般会取这个值作为通道size。

最后在数据归一化:

array = ((array / 127.5) - 1.0).astype(np.float32)

然后在使用yield抛出数组,最后在做判断,最大用300张数据集量化

yield([array])
count += 1
if count > 300:
    break

最后就是设置参数与保存了:

converter = tf.lite.TFLiteConverter.from_frozen_graph('/home/xxx/work/freezed_cifarnet.pb',['input'],['MobilenetV1/Predictions/Reshape_1'])
converter.inference_input_type = tf.lite.constants.INT8
converter.inference_output_type = tf.lite.constants.INT8
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = rep

#量化并保存
tflite_quant_model = converter.convert()
open("test.tflite","wb").write(tflite_quant_model)

完整代码:

import tensorflow as tf
import io
import PIL
import numpy as np

def rep():
    #需要是验证集的数据源
    record_iterator = tf.python_io.tf_record_iterator(path='/home/zhihao/models/research/slim/ci_data/cifar10_train.tfrecord')
    count = 0
    #将图像从protobu取出来量化成数组
    for string_record in record_iterator:
        example = tf.train.Example()
        example.ParseFromString(string_record)
        #这里是你存放图像数据的消息协议名
        image_stream = io.BytesIO(example.features.feature['image/encoded'].bytes_list.value[0])
        image = PIL.Image.open(image_stream)
        #这里将它固定量化成96x96的数组大小,这样方便优化
        image = image.resize((96,96))
        #量化,L=灰度图,1个bit表示三个像素点
        image = image.convert('L')
        #扩维与归一化
        array = np.array(image)
        array = np.expand_dims(array,axis=2)
        array = np.expand_dims(array,axis=0)
        array = ((array / 127.5) - 1.0).astype(np.float32)
        yield([array])
        count += 1
        #最大量化三百张
        if count > 300:
            break

#你的PB文件,这个文件要是包含神经网络权重的PB文件
converter = tf.lite.TFLiteConverter.from_frozen_graph('/home/xxx/work/freezed_cifarnet.pb',['input'],['MobilenetV1/Predictions/Reshape_1'])
converter.inference_input_type = tf.lite.constants.INT8
converter.inference_output_type = tf.lite.constants.INT8
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = rep

#量化并保存
tflite_quant_model = converter.convert()
open("test.tflite","wb").write(tflite_quant_model)

以上是关于Tensor Flow PB文件量化到TFLITE的主要内容,如果未能解决你的问题,请参考以下文章

Tensor Flow V2:将Tensor Flow H5模型文件转换为tflite

Tensor Flow量化里representative_dataset参数是什么意思?

Tensor Flow量化里representative_dataset参数是什么意思?

Tensor Flow Lite C++ API 介绍

Tensor Flow Lite C++ API 介绍

使用 toco 将假量化 tensorflow 模型(.pb)转换为 tensorflow lite 模型(.tflite)失败