在 TensorFlow 对象检测 API 中打印类名和分数
Posted
技术标签:
【中文标题】在 TensorFlow 对象检测 API 中打印类名和分数【英文标题】:Printing class name and score in Tensorflow Object Detection API 【发布时间】:2018-05-17 13:35:22 【问题描述】:我正在使用 Tensorflow 对象检测 API,一切正常,但我想打印一个格式为 Object name , Score 或类似格式的 dict 或数组,我需要的只是对象名称和分数。
我尝试了以下代码:
with detection_graph.as_default():
with tf.Session(graph=detection_graph) as sess:
# Definite input and output Tensors for detection_graph
image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
# Each box represents a part of the image where a particular object was detected.
detection_boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
# Each score represent how level of confidence for each of the objects.
# Score is shown on the result image, together with the class label.
detection_scores = detection_graph.get_tensor_by_name('detection_scores:0')
detection_classes = detection_graph.get_tensor_by_name('detection_classes:0')
num_detections = detection_graph.get_tensor_by_name('num_detections:0')
for image_path in TEST_IMAGE_PATHS:
image = Image.open(image_path)
# the array based representation of the image will be used later in order to prepare the
# result image with boxes and labels on it.
image_np = load_image_into_numpy_array(image)
# Expand dimensions since the model expects images to have shape: [1, None, None, 3]
image_np_expanded = np.expand_dims(image_np, axis=0)
# Actual detection.
(boxes, scores, classes, num) = sess.run(
[detection_boxes, detection_scores, detection_classes, num_detections],
feed_dict=image_tensor: image_np_expanded)
print ([category_index.get(value) for index,value in enumerate(classes[0]) if scores[0,index] > 0.5])
threshold = 0.5 # in order to get higher percentages you need to lower this number; usually at 0.01 you get 100% predicted objects
print(len(np.where(scores[0] > threshold)[0])/num_detections[0])
这部分工作正常
print ([category_index.get(value) for index,value in enumerate(classes[0]) if scores[0,index] > 0.5])
它正在打印 ['name': 'computer', 'id': 1] 他们有什么办法可以将该对象的分数添加到字典中吗??
我在他们使用的 *** 上看到了另一个问题:
threshold = 0.5 # in order to get higher percentages you need to lower this number; usually at 0.01 you get 100% predicted objects
print(len(np.where(scores[0] > threshold)[0])/num_detections[0])
这给了我 Tensor("truediv:0", dtype=float32) 但即使它有效也不够,因为我没有对象的名称。
谢谢
【问题讨论】:
【参考方案1】:所以这是对我有用的解决方案。 (如果您仍在寻找解决方案,那就是)
# The following code replaces the 'print ([category_index...' statement
objects = []
for index, value in enumerate(classes[0]):
object_dict =
if scores[0, index] > threshold:
object_dict[(category_index.get(value)).get('name').encode('utf8')] = \
scores[0, index]
objects.append(object_dict)
print objects
【讨论】:
以上是关于在 TensorFlow 对象检测 API 中打印类名和分数的主要内容,如果未能解决你的问题,请参考以下文章
TensorFlow 对象检测 API - 对象检测 api 中的损失意味着啥?
如何在 Tensorflow 对象检测 api 中评估预训练模型