将基于 TensorFlow GraphDef 的模型导入 TensorFlow.js

Posted 黑胡桃实验室

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了将基于 TensorFlow GraphDef 的模型导入 TensorFlow.js相关的知识,希望对你有一定的参考价值。


基于 TensorFlow GraphDef 的模型(通常通过 Python API 创建)可以采用以下格式之一保存:

  1. TensorFlow SavedModel(保存模型)

  2. Frozen Model(冷冻模型)

  3. Session Bundle(会话包)

  4. Tensorflow Hub module(TensorFlow Hub 模块)

上述所有格式都可以通过 TensorFlow.js 转换器 转换为 TensorFlow.js Web的友好形式,可以直接加载到 TensorFlow.js 中进行推理。

(注意: TensorFlow 已经弃用 Session Bundle 会话包格式,请将模型迁移到 savedmodel 格式。)


1要求


转换过程需要 Python 环境;您可能需要 pipenv 或者 virtualenv 来保持一个独立的环境,要安装转换器,请运行以下命令:


     
       
       
     
  1. pip install tensorflowjs


将 TensorFlow 模型导入 TensorFlow.js 需要两个步骤。 首先,将现有模型转换为 TensorFlow.js Web 格式,然后将其加载到 TensorFlow.js 中。


2步骤一:将现有 TensorFlow 模型转换为 TensorFlow.js Web 格式


运行 pip 包提供的转换器脚本:

用法: 以 SavedModel(保存模型)为例:


     
       
       
     
  1. tensorflowjs_converter \

  2.   --input_format=tf_saved_model \

  3.   --output_node_names='MobilenetV1/Predictions/Reshape_1' \

  4.   --saved_model_tags=serve \

  5.   /mobilenet/saved_model \

  6.   /mobilenet/web_model


以 Frozen model(冷冻模型)为例:


     
       
       
     
  1. tensorflowjs_converter \

  2.   --input_format=tf_frozen_model \

  3.   --output_node_names='MobilenetV1/Predictions/Reshape_1' \

  4.   /mobilenet/frozen_model.pb \

  5.   /mobilenet/web_model


以 Session bundle model(会话包模型)为例:


     
       
       
     
  1. tensorflowjs_converter \

  2.   --input_format=tf_session_bundle \

  3.   --output_node_names='MobilenetV1/Predictions/Reshape_1' \

  4.   /mobilenet/session_bundle \

  5.   /mobilenet/web_model


以 TensorFlow Hub module(TensorFlow Hub 模块)为例:


     
       
       
     
  1. tensorflowjs_converter \

  2.   --input_format=tf_hub \

  3.   'https://tfhub.dev/google/imagenet/mobilenet_v1_100_224/classification/1' \

  4.   /mobilenet/web_model


位置参数

描述

input_path SavedModel 目录、session bundle 目录、frozen model 文件或 TensorFlow Hub module handle或完整的 url 链接。
output_path 所有输出项目的路径。


选项

描述

--input_format 输入模型的格式,使用 tf_saved_model 表示 SavedModel,使用 tf_frozen_model 表示 frozen model,使用 tf_session_bundle 表示 session bundle, 使用 tf_hub 表示 TensorFlow Hub module , 使用 keras 表示 Keras HDF5。
--output_node_names 输出节点的名称,以逗号分隔。
--saved_model_tags 仅适用于 SavedModel 转换,以逗号分隔格式加载 MetaGraphDef 的标签。 默认为 serve
--signature_name 仅适用于 TensorFlow Hub 模块转换,需要加载签名。默认为 default。参考 https://www.tensorflow.org/hub/common_signatures/ 。


使用以下命令获取详细帮助信息:


     
       
       
     
  1. tensorflowjs_converter --help


转换器生成的文件

上面的转换脚本生成 3 种类型的文件:

  • web_model.pb (dataflow graph,数据流图)

  • weights_manifest.json (权重清单文件)

  • group1-shard\*of\* (二进制权重文件的集合)

例如,以下是 MobileNet 模型转换和提供的位置:

     
       
       
     
  1. https://storage.cloud.google.com/tfjs-models/savedmodel/mobilenet_v1_1.0_224/optimized_model.pb

  2. https://storage.cloud.google.com/tfjs-models/savedmodel/mobilenet_v1_1.0_224/weights_manifest.json

  3. https://storage.cloud.google.com/tfjs-models/savedmodel/mobilenet_v1_1.0_224/group1-shard1of5

  4. ...

  5. https://storage.cloud.google.com/tfjs-models/savedmodel/mobilenet_v1_1.0_224/group1-shard5of5


3步骤二:在浏览器中加载和运行


安装 tfjs-converter npm 包

yarn add@tensorflow/tfjs 或者 npm install@tensorflow/tfjs

实例化 FrozenModel class 并进行推理。


     
       
       
     
  1. import * as tf from '@tensorflow/tfjs';

  2. import {loadFrozenModel} from '@tensorflow/tfjs-converter';

  3. const MODEL_URL = 'https://.../mobilenet/web_model.pb';

  4. const WEIGHTS_URL = 'https://.../mobilenet/weights_manifest.json';

  5. const model = await loadFrozenModel(MODEL_URL, WEIGHTS_URL);

  6. const cat = document.getElementById('cat');

  7. model.execute({input: tf.fromPixels(cat)});


查看我们正在运行的 MobileNet demo(MobileNet示例)。


如果您的服务器请求访问模型文件的凭据,那么您可以提供可选的 RequestOption 参数,它将直接传递给 fetch 函数调用。


     
       
       
     
  1. const model = await loadFrozenModel(MODEL_URL, WEIGHTS_URL,

  2.   {credentials: 'include'});


4支持的 Ops(操作)


目前 TensorFlow.js 只支持有限的 TensorFlow 操作。请参阅 完整列表。如果您的模型使用任何不受支持的操作,这个 tensorflowjs_converter 脚本将执行失败并生成模型中不支持的操作的列表。请提交 issues 告诉我们您还需要我们给您哪些支持。


5只加载权重


如果您只想加载权重,可以使用以下代码段。


      
        
        
      
  1. import * as tf from '@tensorflow/tfjs';

  2. const weightManifestUrl = "https://example.org/model/weights_manifest.json";

  3. const manifest = await fetch(weightManifestUrl);

  4. this.weightManifest = await manifest.json();

  5. const weightMap = await tf.io.loadWeights(

  6.       this.weightManifest, "https://example.org/model");


✄----------------------------------

本文是黑胡桃实验室与 Googler、GDE 协作翻译的 TensorFlow.js 中文尝鲜版,英文版首发于 https://js.tensorflow.org/


探索AI

触摸科技

黑胡桃实验室

杭州·赛银国际广场4幢1楼

BlackWalnut Labs. 


以上是关于将基于 TensorFlow GraphDef 的模型导入 TensorFlow.js的主要内容,如果未能解决你的问题,请参考以下文章

如何使用来自 Google AutoML 视觉分类的 TensorFlow Frozen GraphDef (single saved_model.pb) 进行推理和迁移学习

“检查您的 GraphDef 解释二进制文件是不是与您的 GraphDef 生成二进制文件是最新的。”

带你了解TensorFlow pb模型常用处理方法

干货基于TensorFlow卷积神经网络的短期股票预测

建设基于TensorFlow的深度学习环境

Tensorflow Serving 初探