所有的Tensorflow模型都可以嵌入到移动设备

Posted AI前线

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了所有的Tensorflow模型都可以嵌入到移动设备相关的知识,希望对你有一定的参考价值。

作者|第四范式-陈迪豪
译者|alexyi
编辑|Emily
AI 前线导读:我们通常将 TensorFlow 模型保存为可以被 TesnorFlow Serving 和 Simple TensorFlow Serving 调用的本地文件。然后移动设备可以通过在线访问这些服务实现移动端的一些应用。

如果移动应用可以在不通过在线请求的情况下访问模型并进行推理,这将会极大地提高 TensorFlow 的灵活性。TensorFlow Mobile 以及 TensorFlow Lite 的官方文档给出了一些可以使用图像模型的示例 App,但是实现起来并不简单,并且难以将你的模型嵌入到安卓系统。

这里我们提供了一种通用教程,可以将所有的 TensorFlow 模型嵌入到移动设备。你可以直接克隆 tensorflow template application 的源代码并且针对你自己的机器学习脚本调整模型文件。

更多干货内容请关注微信公众号“AI 前线”,(ID:ai-front)
TensorFlow 模型格式

如果我们要将 TensorFlow 模型嵌入移动设备,首先需要确定模型的格式。正如官方文档所述,TensorFlow 模型有很多格式,例如,Checkpoint,Exporter,SaveModel,Frozen graph 等。

Checkpoint 格式用于在训练时存模型快照以及恢复训练。所以不要将它作为最终的保存格式,尽管它包含了推理所需的全部变量。

SaceModel 格式通常用于基于 TensorFlow Serving 的在线服务。你可以使用 TensorFlow C++ API 和 Python API 加载这种格式的模型。而在实际应用中,在线客户端通常倾向于使用开源、轻量级、RESTful 的服务,即 Simple TensorFlow Serving。

对于移动设备,我们需要使用 GraphDef 对象和 Checkpoint 文件生成 Frozen graph。TensorFLow 提供了导出 GraphDef 对象的 API,你可以通过这段代码来轻松实现。

graph_file_name = "graph.pb"
tf.train.write_graph(sess.graph_def, FLAGS.model_path, graph_file_name, as_text=False)

然后我们可以使用 TensorFlow 库中的 freeze_graph.py 脚本生成二进制 protobuf 格式的 Frozen graph 文件。

python ./freeze_graph.py --input_graph=/Users/tobe/code/tensorflow_template_application/model/graph.pb --input_checkpoint=/Users/tobe/code/tensorflow_template_application/checkpoint/checkpoint.ckpt-200 --output_graph=./frozen_graph.pb --output_node_names=output_keys,output_prediction,output_softmax --input_binary=True

需要注意的是,我们应当为生成的模型指定可用于推理的输出节点名称,不同的 TensorFlow 应用可能会根据其用途更改输出点的名称。

如果你不想自己生成 Frozen graph 文件,可以直接克隆 teansorflow template application 的源代码,其中包含了移动端模型文件。

安卓的 Java 代码

现在我们有了 TensorFlow 模型文件,接下来只需要加载模型并使用我们的数据进行推理就可以了。

多亏了 TensorFlow Mobile 的工作,我们不需要自己编写 C++ 和 JNI 代码来加载 TensorFlow 模型。有一个名为“TensorFlowInferenceInterface”的封装好的类可以用来加载模型并进行推理。

AssetManager assetManager = getAssets();
String MODEL_FILE = "file:///android_asset/tensorflow_template_application_model.pb";
TensorFlowInferenceInterface inferenceInterface = new TensorFlowInferenceInterface(assetManager, MODEL_FILE);

众所周知,张量 (Tensor) 是 TensorFlow 的核心概念。因此我们需要使用自己构建的张量数据作为输入,而不是原始的图像像素,但是它主要用于 Python 和 C++ 接口。对于安卓客户端,我们需要构造与张量尺寸相同的“nd-array”作为输入。

int[] keysValues = new int[2];
keysValues[0] = 1;
keysValues[1] = 2;
float[] featruesValues = new float[18];
for (int i = 0; i < 18; i++) {
    featruesValues[i] = 1f;
}
String[] inputNames = new String[2];
inputNames[0] = "keys";
inputNames[1] = "features";
inferenceInterface.feed(inputNames[0], keysValues, 2, 1);
inferenceInterface.feed(inputNames[1], featruesValues, 2, 9);

这是一种使用 TensorFlow 模型的普遍方式。如果你需要访问自然语言处理 (NLP) 模型或者图像模型,你也可以将输入解析为 Java 的"nd-array"格式。实例代码展示了怎样为一个 tensorflow_template_application 模型构建张量并针对你的模型修改实际数据。

最后,我们可以使用 Java 的“nd-array”对象进行推理并得到输出。

String[] outputNames = new String[3];
outputNames[0] = "output_keys";
outputNames[1] = "output_prediction";
outputNames[2] = "output_softmax";
inferenceInterface.run(outputNames, logStats);
int[] keysOutputs = new int[2];
long[] predictionOutput = new long[2];
float[] softmaxOutput = new float[4];
inferenceInterface.fetch(outputNames[0], keysOutputs);
inferenceInterface.fetch(outputNames[1], predictionOutput);
inferenceInterface.fetch(outputNames[2], softmaxOutput);

当然,请确保这些张量数据的类型和形状能够与 TensorFlow 的 Python 脚本所定义的模型兼容。得到了模型的输出后,你就可以使用这些输出数组实现其他功能了。

离线推理的完整过程不需要 gRPC 客户端或者任何网络连接。这是一个 TensorFlow 示例模型嵌入安卓设备的教程,不过你可以将它扩展到所有其他模型。本文中所有的代码都来自于 GitHub 上的开源项目。也许,你可能不知道 TensorFlow Mobile 的所有细节,但是仍可以通过 tesnorflow template application 随意尝试编译你自己的安卓客户端。

总结

总的来说,TensorFlow 的移动端模型实现起来并不难。尽管官方的安卓 Demo 应用对于一般情况来说并不够好,但我们可以从这篇文章中学到一些实践经验,并有信心将我们所有的机器学习模型都移植到离线设备上。

如果觉得内容不错,记得给我们「留言」和「点赞」,给编辑鼓励一下!

以上是关于所有的Tensorflow模型都可以嵌入到移动设备的主要内容,如果未能解决你的问题,请参考以下文章

震惊!谷歌正式发布移动端深度学习框架TensorFlow Lite

谷歌移动端深度学习框架 TensorFlow Lite 正式发布

@移动开发者,谷歌发布移动端深度学习框架TensorFlow Lite

移动端目标识别(3)——使用TensorFlow Lite将tensorflow模型部署到移动端(ssd)之Running on mobile with TensorFlow Lite (写的很乱,回

机器学习笔记 - TensorFlow Lite设备端机器学习的模型优化

Tensorflow:如何将预训练模型已经嵌入的数据输入到 LSTM 模型中?