如何查看 .tflite 文件中的权重?

Posted

技术标签:

【中文标题】如何查看 .tflite 文件中的权重?【英文标题】:How can I view weights in a .tflite file? 【发布时间】:2019-02-06 06:21:13 【问题描述】:

我得到了 MobileNet 的预训练 .pb 文件,发现它没有被量化,而完全量化的模型应该转换为 .tflite 格式。由于我不熟悉移动应用程序开发工具,如何从 .t​​flite 文件中获取 MobileNet 的完全量化权重。更准确地说,如何提取量化参数并查看其数值?

【问题讨论】:

【参考方案1】:

我也在研究 TFLite 的工作原理。我发现的可能不是最好的方法,我将不胜感激任何专家意见。这是我到目前为止使用flatbuffer python API 发现的。

首先,您需要使用 flatbuffer 编译架构。输出将是一个名为 tflite 的文件夹。

flatc --python tensorflow/contrib/lite/schema/schema.fbs

然后你可以加载模型并获得你想要的张量。 Tensor 有一个名为 Buffer() 的方法,根据架构,

引用模型根目录的缓冲区表的索引。

因此它会将您指向数据的位置。

from tflite import Model
buf = open('/path/to/mode.tflite', 'rb').read()
model = Model.Model.GetRootAsModel(buf, 0)
subgraph = model.Subgraphs(0)
# Check tensor.Name() to find the tensor_idx you want
tensor = subgraph.Tensors(tensor_idx) 
buffer_idx = tensor.Buffer()
buffer = model.Buffers(buffer_idx)

之后你就可以通过调用buffer.Data()来读取数据了

参考: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/schema/schema.fbs https://github.com/google/flatbuffers/tree/master/samples

【讨论】:

schema.fbs file 中,有一个名为Model 的表,其中有一个名为description 的字段。我想在这个字段中写一个字符串(比如说一个单行模型描述)。然后如何加载 tflite 模型,并使用这些额外的元数据更新 tflite 文件?感谢一些帮助。【参考方案2】:

Netron 模型查看器具有漂亮的数据视图和导出功能,以及漂亮的网络图视图。 https://github.com/lutzroeder/netron

【讨论】:

【参考方案3】:

使用 TensorFlow 2.0,您可以使用以下脚本提取权重和有关张量的一些信息(形状、dtype、名称、量化) - 灵感来自 TensorFlow documentation

import tensorflow as tf
import h5py


# Load TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path="v3-large_224_1.0_uint8.tflite")
interpreter.allocate_tensors()

# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()


# get details for each layer
all_layers_details = interpreter.get_tensor_details() 


f = h5py.File("mobilenet_v3_weights_infos.hdf5", "w")   

for layer in all_layers_details:
     # to create a group in an hdf5 file
     grp = f.create_group(str(layer['index']))

     # to store layer's metadata in group's metadata
     grp.attrs["name"] = layer['name']
     grp.attrs["shape"] = layer['shape']
     # grp.attrs["dtype"] = all_layers_details[i]['dtype']
     grp.attrs["quantization"] = layer['quantization']

     # to store the weights in a dataset
     grp.create_dataset("weights", data=interpreter.get_tensor(layer['index']))


 f.close()

【讨论】:

看起来 tflite 将索引分配给图层的顺序与模型中图层的排序顺序不同。相反,它按名称对图层列表进行排序,然后将索引分配给排序列表。那么如何恢复正确的层序列呢?有什么解决方法吗? (我正在使用量化的 Mobilenetv2 模型)【参考方案4】:

您可以使用 Netron 应用程序查看它 macOS:下载 .dmg 文件或运行 brew install netron

Linux:下载 .AppImage 文件或运行 snap install netron

Windows:下载 .exe 安装程序或运行 winget install netron

浏览器:启动浏览器版本。

Python 服务器:运行 pip install netron 和 netron [FILE] 或 netron.start('[FILE]')。

【讨论】:

以上是关于如何查看 .tflite 文件中的权重?的主要内容,如果未能解决你的问题,请参考以下文章

Tensor Flow PB文件量化到TFLITE

模型量化原理及tflite示例

Tensor Flow V2:将Tensor Flow H5模型文件转换为tflite

Tensor Flow V2:将Tensor Flow H5模型文件转换为tflite

如何在脚本中加载 tflite 模型?

如何从 tflite 模型输出形状 [1, 28, 28,1] 的数组作为 android 中的图像