Tensorflow v1 对象检测 api mask_rcnn_inception_v2_coco 模型批量推理

Posted

技术标签:

【中文标题】Tensorflow v1 对象检测 api mask_rcnn_inception_v2_coco 模型批量推理【英文标题】:Tensorflow v1 object detection api mask_rcnn_inception_v2_coco model batch inferencing 【发布时间】:2020-11-23 06:19:31 【问题描述】:

我尝试使用 tensorflow 对象检测 api 中可用的 mask_rcnn_inception_v2_coco 模型进行分割任务。在这里,我想对视频进行推理。

首先我进行了推理,一次一帧(批量大小 =1)。 然后性能(每秒帧数)非常低。为了获得更好的“fps”值,我尝试更改代码以支持批量推理。

当我使用 batch_size >1 运行代码时,出现以下错误。

我已将相同的批量推理代码用于其他模型,例如 ssd_inception_v2_coco(它们是对象检测模型),它们运行时没有任何问题。

这是否意味着 ma​​sk_rcnn_inception_v2_coco 不支持批处理? 或者这是另一个问题?

box_ind is deprecated, use box_indices instead
Traceback (most recent call last):
  File "/home/omega/anaconda3/envs/tf_gpu_env/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 1365, in _do_call
    return fn(*args)
  File "/home/omega/anaconda3/envs/tf_gpu_env/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 1350, in _run_fn
    target_list, run_metadata)
  File "/home/omega/anaconda3/envs/tf_gpu_env/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 1443, in _call_tf_sessionrun
    run_metadata)
tensorflow.python.framework.errors_impl.InvalidArgumentError: 2 root error(s) found.
  (0) Invalid argument: Can not squeeze dim[0], expected a dimension of 1, got 2
     [[node Squeeze_4]]
     [[detection_boxes/_333]]
  (1) Invalid argument: Can not squeeze dim[0], expected a dimension of 1, got 2
     [[node Squeeze_4]]
0 successful operations.
0 derived errors ignored.

【问题讨论】:

【参考方案1】:

我遇到了同样的问题。我明白原因,但暂时不知道如何解决。

至于遮罩部分,默认图表中还有一个部分,用于将图像重新构图为原始形状。

但是 tf.utils.op 中提供的功能目前只接受单张图片。

您可以看到没有批处理的代码,它接受 box_masks 作为大小为 [num_masks, mask_height, mask_width] 的张量。

def reframe_box_masks_to_image_masks(box_masks, boxes, image_height,
                                     image_width):
    """Transforms the box masks back to full image masks.

    Embeds masks in bounding boxes of larger masks whose shapes correspond to
    image shape.

    Args:
      box_masks: A tf.float32 tensor of size [num_masks, mask_height, mask_width].
      boxes: A tf.float32 tensor of size [num_masks, 4] containing the box
             corners. Row i contains [ymin, xmin, ymax, xmax] of the box
             corresponding to mask i. Note that the box corners are in
             normalized coordinates.
      image_height: Image height. The output mask will have the same height as
                    the image height.
      image_width: Image width. The output mask will have the same width as the
                   image width.

    Returns:
      A tf.float32 tensor of size [num_masks, image_height, image_width].
    """

    """# TODO(rathodv): Make this a public function."""

    def reframe_box_masks_to_image_masks_default():
        """The default function when there are more than 0 box masks."""

        def transform_boxes_relative_to_boxes(boxes, reference_boxes):
            boxes = tf.reshape(boxes, [-1, 2, 2])
            min_corner = tf.expand_dims(reference_boxes[:, 0:2], 1)
            max_corner = tf.expand_dims(reference_boxes[:, 2:4], 1)
            transformed_boxes = (boxes - min_corner) / (max_corner - min_corner)
            return tf.reshape(transformed_boxes, [-1, 4])

        box_masks_expanded = tf.expand_dims(box_masks, axis=3)
        num_boxes = tf.shape(box_masks_expanded)[0]
        unit_boxes = tf.concat(
            [tf.zeros([num_boxes, 2]), tf.ones([num_boxes, 2])], axis=1)
        reverse_boxes = transform_boxes_relative_to_boxes(unit_boxes, boxes)
        return tf.image.crop_and_resize(
            image=box_masks_expanded,
            boxes=reverse_boxes,
            box_ind=tf.range(num_boxes),
            crop_size=[image_height, image_width],
            extrapolation_value=0.0)

    image_masks = tf.cond(
        tf.shape(box_masks)[0] > 0,
        reframe_box_masks_to_image_masks_default,
        lambda: tf.zeros([0, image_height, image_width, 1], dtype=tf.float32))
    return tf.squeeze(image_masks, axis=3)

【讨论】:

以上是关于Tensorflow v1 对象检测 api mask_rcnn_inception_v2_coco 模型批量推理的主要内容,如果未能解决你的问题,请参考以下文章

Tensorflow 对象检测 Api M1 Macbook 冲突错误

Tensorflow对象检测:为啥使用ssd mobilnet v1时图像中的位置会影响检测精度?

TensorFlow 对象检测 API - 对象检测 api 中的损失意味着啥?

TensorFlow 对象检测 API 中未检测到任何内容

具有奇怪检测结果的 TensorFlow 对象检测 api

Tensorflow 对象检测 API 中的过拟合