尝试向检测模型“ssd_mobilenet_v2”添加警报,会引发错误

Posted

技术标签:

【中文标题】尝试向检测模型“ssd_mobilenet_v2”添加警报,会引发错误【英文标题】:Trying to add an alert to detection model "ssd_mobilenet_v2", throws an error 【发布时间】:2022-01-09 11:40:00 【问题描述】:

我正在尝试使用“ssd_mobilenet_v2_fpn_keras”添加警报系统

检测模型加载到下面的函数中

 def detect_fn(image):
    image, shapes = detection_model.preprocess(image)
    prediction_dict = detection_model.predict(image, shapes)
    detections = detection_model.postprocess(prediction_dict, shapes)
    return detections

图像转换为张量

input_tensor = tf.convert_to_tensor(np.expand_dims(image_np, 0), dtype=tf.float32)

张量被馈送到检测模型

detections = detect_fn(input_tensor)

检测模型的输出是一个字典,具有以下键:

dict_keys(['detection_boxes', 'detection_scores', 'detection_classes', 'raw_detection_boxes', 'raw_detection_scores', 'detection_multiclass_scores', 'detection_anchor_indices', 'num_detections'])

detections[detection_classes],给出以下输出,即 0 是 ClassA,1 是 ClassB

[0 1 1 0 0 1 0 0 1 0 1 1 0 0 1 0 1 1 0 1 0 1 1 0 0 1 0 0 1 0 1 0 0 1 1 1 1 0 0 0 1 1 1 0 0 1 1 1 0 1 0 1 0 0 0 0 1 0 0 1 0 0 1 0 1 0 0 1 0 0 0 0 1 0 1 1 0 1 1 0 1 1 0 0 0 0 0 0 0 0 0 1 1 1 0 0 0 1 0 1]

detections['detection_scores'] 给出检测到的每个框的分数(下面显示了几个)

[0.988446 0.7998712 0.1579772 0.13801616 0.13227147 0.12731305 0.09515342 0.09203091 0.09191579 0.08860824 0.08313078 0.07684237

我正在尝试Print("Attention needed"),如果观察到检测类 B 即 1

for key in detections['detection_classes']:
if key==1:
    print('Alert')

当我尝试这样做时,我得到一个错误

`ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

如何让它发挥作用?

我希望代码打印“需要注意”是 Class =1 或 A 并且 detection_scores >= 14

Code Explained, a bit further


完整代码的链接如下:

Tutorial on YouTube GitHub sources repository

【问题讨论】:

【参考方案1】:

如错误消息中所述,您应该使用.any()。喜欢:

if (key == 1).any():
  print('Alert')

因为key == 1 将是一个带有[False, True, True, False, ...] 的数组

您可能还想检测超过特定分数的分数,例如 0.7:

for key, score in zip(
  detections['detection_classes'],
  detections['detection_scores']):
  if score > 0.7 and key == 1:
    print('Alert')
    break

【讨论】:

以上是关于尝试向检测模型“ssd_mobilenet_v2”添加警报,会引发错误的主要内容,如果未能解决你的问题,请参考以下文章

将对象检测模型嵌入 iOS 应用程序,并将其部署在 UIView 的内容而不是相机流上?

在树莓派 (tensorflow) 上进行对象检测的项目

如何修改 Tensorflow 2.0 中的 epoch 数?

Tensorflow 图节点是交换的

如何向预训练的对象检测模型添加其他类并训练它以检测所有类(预训练 + 新)?

基于YOLO的手部检测和计数实现(课程设计,训练测试+模型剪枝+模型压缩)