将基于 TensorFlow GraphDef 的模型导入 TensorFlow.js
Posted 黑胡桃实验室
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了将基于 TensorFlow GraphDef 的模型导入 TensorFlow.js相关的知识,希望对你有一定的参考价值。
基于 TensorFlow GraphDef 的模型(通常通过 Python API 创建)可以采用以下格式之一保存:
TensorFlow SavedModel(保存模型)
Frozen Model(冷冻模型)
Session Bundle(会话包)
Tensorflow Hub module(TensorFlow Hub 模块)
上述所有格式都可以通过 TensorFlow.js 转换器 转换为 TensorFlow.js Web的友好形式,可以直接加载到 TensorFlow.js 中进行推理。
(注意: TensorFlow 已经弃用 Session Bundle 会话包格式,请将模型迁移到 savedmodel 格式。)
1要求
转换过程需要 Python 环境;您可能需要 pipenv 或者 virtualenv 来保持一个独立的环境,要安装转换器,请运行以下命令:
pip install tensorflowjs
将 TensorFlow 模型导入 TensorFlow.js 需要两个步骤。 首先,将现有模型转换为 TensorFlow.js Web 格式,然后将其加载到 TensorFlow.js 中。
2步骤一:将现有 TensorFlow 模型转换为 TensorFlow.js Web 格式
运行 pip 包提供的转换器脚本:
用法: 以 SavedModel(保存模型)为例:
tensorflowjs_converter \
--input_format=tf_saved_model \
--output_node_names='MobilenetV1/Predictions/Reshape_1' \
--saved_model_tags=serve \
/mobilenet/saved_model \
/mobilenet/web_model
以 Frozen model(冷冻模型)为例:
tensorflowjs_converter \
--input_format=tf_frozen_model \
--output_node_names='MobilenetV1/Predictions/Reshape_1' \
/mobilenet/frozen_model.pb \
/mobilenet/web_model
以 Session bundle model(会话包模型)为例:
tensorflowjs_converter \
--input_format=tf_session_bundle \
--output_node_names='MobilenetV1/Predictions/Reshape_1' \
/mobilenet/session_bundle \
/mobilenet/web_model
以 TensorFlow Hub module(TensorFlow Hub 模块)为例:
tensorflowjs_converter \
--input_format=tf_hub \
'https://tfhub.dev/google/imagenet/mobilenet_v1_100_224/classification/1' \
/mobilenet/web_model
位置参数 |
描述 |
input_path | SavedModel 目录、session bundle 目录、frozen model 文件或 TensorFlow Hub module handle或完整的 url 链接。 |
output_path | 所有输出项目的路径。 |
选项 |
描述 |
--input_format | 输入模型的格式,使用 tf_saved_model 表示 SavedModel ,使用 tf_frozen_model 表示 frozen model,使用 tf_session_bundle 表示 session bundle, 使用 tf_hub 表示 TensorFlow Hub module , 使用 keras 表示 Keras HDF5。 |
--output_node_names | 输出节点的名称,以逗号分隔。 |
--saved_model_tags | 仅适用于 SavedModel 转换,以逗号分隔格式加载 MetaGraphDef 的标签。 默认为 serve。 |
--signature_name | 仅适用于 TensorFlow Hub 模块转换,需要加载签名。默认为 default。参考 https://www.tensorflow.org/hub/common_signatures/ 。 |
使用以下命令获取详细帮助信息:
tensorflowjs_converter --help
转换器生成的文件
上面的转换脚本生成 3 种类型的文件:
web_model.pb
(dataflow graph,数据流图)weights_manifest.json
(权重清单文件)group1-shard\*of\*
(二进制权重文件的集合)
例如,以下是 MobileNet 模型转换和提供的位置:
https://storage.cloud.google.com/tfjs-models/savedmodel/mobilenet_v1_1.0_224/optimized_model.pb
https://storage.cloud.google.com/tfjs-models/savedmodel/mobilenet_v1_1.0_224/weights_manifest.json
https://storage.cloud.google.com/tfjs-models/savedmodel/mobilenet_v1_1.0_224/group1-shard1of5
...
https://storage.cloud.google.com/tfjs-models/savedmodel/mobilenet_v1_1.0_224/group1-shard5of5
3步骤二:在浏览器中加载和运行
安装 tfjs-converter npm 包
yarn add@tensorflow/tfjs
或者 npm install@tensorflow/tfjs
实例化 FrozenModel class 并进行推理。
import * as tf from '@tensorflow/tfjs';
import {loadFrozenModel} from '@tensorflow/tfjs-converter';
const MODEL_URL = 'https://.../mobilenet/web_model.pb';
const WEIGHTS_URL = 'https://.../mobilenet/weights_manifest.json';
const model = await loadFrozenModel(MODEL_URL, WEIGHTS_URL);
const cat = document.getElementById('cat');
model.execute({input: tf.fromPixels(cat)});
查看我们正在运行的 MobileNet demo(MobileNet示例)。
如果您的服务器请求访问模型文件的凭据,那么您可以提供可选的 RequestOption 参数,它将直接传递给 fetch 函数调用。
const model = await loadFrozenModel(MODEL_URL, WEIGHTS_URL,
{credentials: 'include'});
4支持的 Ops(操作)
目前 TensorFlow.js 只支持有限的 TensorFlow 操作。请参阅 完整列表。如果您的模型使用任何不受支持的操作,这个 tensorflowjs_converter 脚本将执行失败并生成模型中不支持的操作的列表。请提交 issues 告诉我们您还需要我们给您哪些支持。
5只加载权重
如果您只想加载权重,可以使用以下代码段。
import * as tf from '@tensorflow/tfjs';
const weightManifestUrl = "https://example.org/model/weights_manifest.json";
const manifest = await fetch(weightManifestUrl);
this.weightManifest = await manifest.json();
const weightMap = await tf.io.loadWeights(
this.weightManifest, "https://example.org/model");
✄----------------------------------
本文是黑胡桃实验室与 Googler、GDE 协作翻译的 TensorFlow.js 中文尝鲜版,英文版首发于 https://js.tensorflow.org/
探索AI
触摸科技
黑胡桃实验室
杭州·赛银国际广场4幢1楼
BlackWalnut Labs.
以上是关于将基于 TensorFlow GraphDef 的模型导入 TensorFlow.js的主要内容,如果未能解决你的问题,请参考以下文章
如何使用来自 Google AutoML 视觉分类的 TensorFlow Frozen GraphDef (single saved_model.pb) 进行推理和迁移学习