如何在 Tensorflow 中正确使用 tf.metrics.mean_iou 在 Tensorboard 上显示混淆矩阵?

Posted

技术标签:

【中文标题】如何在 Tensorflow 中正确使用 tf.metrics.mean_iou 在 Tensorboard 上显示混淆矩阵?【英文标题】:How to properly use tf.metrics.mean_iou in Tensorflow to show confusion matrix on Tensorboard? 【发布时间】:2019-10-10 14:51:16 【问题描述】:

我发现在 DeeplabV3+ (eval.py) 的 Tensorflow 官方实现中的评估脚本使用 tf.metrics.mean_iou 更新平均 IOU,并将其添加到 Tensorboard 记录。

tf.metrics.mean_iou实际上返回2个张量,一个是计算的mean IOU,另一个是opdate_op,根据官方文档(doc),混淆矩阵。似乎每次要计算 mean_iou 时,都必须先调用 update_op。

我正在尝试将此 update_op 作为张量添加到摘要中,但它不起作用。我的问题是如何将这个混淆矩阵添加到 Tensorboard 中?

我看到了一些关于如何计算混淆矩阵并将其添加到 Tensorboard 的其他线程,以及额外的操作。我只是想知道如果没有这些额外的操作,是否可以做到这一点。

任何帮助将不胜感激。

【问题讨论】:

【参考方案1】:

我会在这里发布我的答案,因为有人支持它。

假设您以下列方式定义了mean_iou op:

    miou, update_op = tf.metrics.mean_iou(
        predictions, labels, dataset.num_of_classes, weights=weights)
    tf.summary.scalar(predictions_tag, miou)

如果你在 Tensorboard 中查看你的图表,你会发现有一个名为 'mean_iou' 的节点,展开这个节点后,你会发现有一个名为 'total_confucion_matrix' 的操作。这是您计算每个类的召回率和精度所需要的。

获取节点名称后,您可以通过tf.summary.text 将其添加到您的tensorboard 或通过tf.print 函数在您的终端中打印。下面贴出一个例子:

    miou, update_op = tf.metrics.mean_iou(
        predictions, labels, dataset.num_of_classes, weights=weights)
    tf.summary.scalar(predictions_tag, miou)
    # Get the correct tensor name of confusion matrix, different graphs may vary
    confusion_matrix = tf.get_default_graph().get_tensor_by_name('mean_iou/total_confusion_matrix:0')

    # Calculate precision and recall matrix
    precision = confusion_matrix / tf.reshape(tf.reduce_sum(confusion_matrix, 1), [-1, 1])
    recall = confusion_matrix / tf.reshape(tf.reduce_sum(confusion_matrix, 0), [-1, 1])

    # Print precision, recall and miou in terminal
    precision_op = tf.print("Precision:\n", precision,
                         output_stream=sys.stdout)
    recall_op = tf.print("Recall:\n", recall,
                         output_stream=sys.stdout)
    miou_op = tf.print("Miou:\n", miou,
                         output_stream=sys.stdout)

    # Add precision and recall matrix in Tensorboard
    tf.summary.text('recall_matrix', tf.dtypes.as_string(recall, precision=4))
    tf.summary.text('precision_matrix', tf.dtypes.as_string(precision, precision=4))

    # Create summary hooks
    summary_op = tf.summary.merge_all()
    summary_hook = tf.contrib.training.SummaryAtEndHook(
        log_dir=FLAGS.eval_logdir, summary_op=summary_op)
    precision_op_hook = tf.train.FinalOpsHook(precision_op)
    recall_op_hook = tf.train.FinalOpsHook(recall_op)
    miou_op_hook = tf.train.FinalOpsHook(miou_op)
    hooks = [summary_hook, precision_op_hook, recall_op_hook, miou_op_hook]

    num_eval_iters = None
    if FLAGS.max_number_of_evaluations > 0:
      num_eval_iters = FLAGS.max_number_of_evaluations

    if FLAGS.quantize_delay_step >= 0:
      tf.contrib.quantize.create_eval_graph()

    tf.contrib.training.evaluate_repeatedly(
        master=FLAGS.master,
        checkpoint_dir=FLAGS.checkpoint_dir,
        eval_ops=[update_op],
        max_number_of_evaluations=num_eval_iters,
        hooks=hooks,
        eval_interval_secs=FLAGS.eval_interval_secs)

然后,您将在 Tensorboard 中汇总准确率和召回率矩阵:

【讨论】:

以上是关于如何在 Tensorflow 中正确使用 tf.metrics.mean_iou 在 Tensorboard 上显示混淆矩阵?的主要内容,如果未能解决你的问题,请参考以下文章

Tensorflow:如何正确使用 Adam 优化器

如何在 Tensorflow 中正确使用 tf.metrics.mean_iou 在 Tensorboard 上显示混淆矩阵?

如何正确使用带有 C++ 的 tensorflow 从 YOLO 模型中获取输出?

如何在使用 tensorflow lite Android API 第一次正确图像检测后停止分类

如何正确地重新标记 TensorFlow 数据集?

如何在 Tensorflow 中计算 R^2