导出后加载 TensorFlow 对象检测模型

Posted

技术标签:

【中文标题】导出后加载 TensorFlow 对象检测模型【英文标题】:Loading TensorFlow Object Detection Model After Export 【发布时间】:2021-12-22 12:50:09 【问题描述】:

我已经按照this 官方教程中提供的步骤使用 TensorFlow API 训练了一个对象检测模型。因此,在整个过程结束时,如the exporting step 中所述,我已将模型保存为以下格式。

my_model/
├─ checkpoint/
├─ saved_model/
└─ pipeline.config

我的问题是,一旦模型被保存为这种格式,我该如何加载它并使用它来进行检测?

我可以使用下面的代码通过训练检查点成功地做到这一点。并且在该点之后(我加载生成最佳结果的检查点)导出模型。

# Load pipeline config and build a detection model
configs = config_util.get_configs_from_pipeline_file(PATH_TO_PIPELINE_CONFIG)
model_config = configs['model']
detection_model = model_builder.build(model_config=model_config, is_training=False)

# Restore checkpoint
ckpt = tf.compat.v2.train.Checkpoint(model=detection_model)
ckpt.restore(PATH_TO_CKPT).expect_partial()

但是,在生产中,我不打算使用这些检查点。我希望从导出的格式中加载模型。

我尝试了以下命令来加载导出的模型,但我没有运气。它没有返回错误,我可以使用下面的 model 变量进行检测,但是输出(边界框、类、分数)不正确,这让我相信加载中缺少一些步骤过程。

model = tf.saved_model.load(path_to_exported_model)

有什么建议吗?

【问题讨论】:

【参考方案1】:

好的,事实证明,代码是正确的。我用另一个模型(也是一个 EfficientDet)进行了测试,代码有效。导出原始模型时似乎出了点问题,我仍在努力解决。

对于那些寻找答案的人,这里是加载和使用已保存模型的完整代码。

# Loading saved mode.
model = tf.saved_model.load(path_to_exported_model)

# Pre-processing image.
image = tf.image.decode_image(open(path_to_image, 'rb').read(), channels=3)
image = tf.expand_dims(image, 0)
image = tf.image.resize(image, (size_expected_by_model, size_expected_by_model))

# Model expects tf.uint8 tensor, but image is read as tf.float32.
image = tf.image.convert_image_dtype(image, tf.uint8)

# Executing object detection.
detections = model(image)

# Formatting returned detections.
num_detections = int(detections.pop('num_detections'))
detections = key: value[0, :num_detections].numpy()
              for key, value in detections.items()

detections['num_detections'] = num_detections

detections['detection_classes'] = detections['detection_classes'].astype(np.int64)

【讨论】:

【参考方案2】:

检查此链接.....Abdul Rehman 几乎没有 python 代码来运行 save_models 检测以推断图像和视频......我广泛使用这些代码,以检查从TF2 Model Zoo,以及在自定义数据集上训练的模型......

https://github.com/abdelrahman-gaber/tf2-object-detection-api-tutorial

【讨论】:

以上是关于导出后加载 TensorFlow 对象检测模型的主要内容,如果未能解决你的问题,请参考以下文章

使用重新训练的 Tensorflow 对象检测模型使用 snpe 进行 pb 到 dlc 转换失败

导出 LD_LIBRARY_PATH 后加载库事件时出错

页面加载后加载mysql查询

如何在 Tensorflow 对象检测 api 中评估预训练模型

Tensorflow 对象检测 API 中的过拟合

如何为 tensorflow 对象检测模型运行 eval.py 作业