尝试向检测模型“ssd_mobilenet_v2”添加警报,会引发错误
Posted
技术标签:
【中文标题】尝试向检测模型“ssd_mobilenet_v2”添加警报,会引发错误【英文标题】:Trying to add an alert to detection model "ssd_mobilenet_v2", throws an error 【发布时间】:2022-01-09 11:40:00 【问题描述】:我正在尝试使用“ssd_mobilenet_v2_fpn_keras”添加警报系统
检测模型加载到下面的函数中
def detect_fn(image):
image, shapes = detection_model.preprocess(image)
prediction_dict = detection_model.predict(image, shapes)
detections = detection_model.postprocess(prediction_dict, shapes)
return detections
图像转换为张量
input_tensor = tf.convert_to_tensor(np.expand_dims(image_np, 0), dtype=tf.float32)
张量被馈送到检测模型
detections = detect_fn(input_tensor)
检测模型的输出是一个字典,具有以下键:
dict_keys(['detection_boxes', 'detection_scores', 'detection_classes', 'raw_detection_boxes', 'raw_detection_scores', 'detection_multiclass_scores', 'detection_anchor_indices', 'num_detections'])
detections[detection_classes]
,给出以下输出,即 0 是 ClassA,1 是 ClassB
[0 1 1 0 0 1 0 0 1 0 1 1 0 0 1 0 1 1 0 1 0 1 1 0 0 1 0 0 1 0 1 0 0 1 1 1 1 0 0 0 1 1 1 0 0 1 1 1 0 1 0 1 0 0 0 0 1 0 0 1 0 0 1 0 1 0 0 1 0 0 0 0 1 0 1 1 0 1 1 0 1 1 0 0 0 0 0 0 0 0 0 1 1 1 0 0 0 1 0 1]
detections['detection_scores']
给出检测到的每个框的分数(下面显示了几个)
[0.988446 0.7998712 0.1579772 0.13801616 0.13227147 0.12731305 0.09515342 0.09203091 0.09191579 0.08860824 0.08313078 0.07684237
我正在尝试Print("Attention needed")
,如果观察到检测类 B 即 1
for key in detections['detection_classes']:
if key==1:
print('Alert')
当我尝试这样做时,我得到一个错误
`ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
如何让它发挥作用?
我希望代码打印“需要注意”是 Class =1 或 A 并且 detection_scores >= 14
Code Explained, a bit further
完整代码的链接如下:
Tutorial on YouTube GitHub sources repository【问题讨论】:
【参考方案1】:如错误消息中所述,您应该使用.any()
。喜欢:
if (key == 1).any():
print('Alert')
因为key == 1
将是一个带有[False, True, True, False, ...]
的数组
您可能还想检测超过特定分数的分数,例如 0.7:
for key, score in zip(
detections['detection_classes'],
detections['detection_scores']):
if score > 0.7 and key == 1:
print('Alert')
break
【讨论】:
以上是关于尝试向检测模型“ssd_mobilenet_v2”添加警报,会引发错误的主要内容,如果未能解决你的问题,请参考以下文章
将对象检测模型嵌入 iOS 应用程序,并将其部署在 UIView 的内容而不是相机流上?
如何修改 Tensorflow 2.0 中的 epoch 数?