Spark ML - MulticlassClassificationEvaluator - 我们可以通过每个类标签获得精度/召回率吗?

Posted

技术标签:

【中文标题】Spark ML - MulticlassClassificationEvaluator - 我们可以通过每个类标签获得精度/召回率吗?【英文标题】:Spark ML - MulticlassClassificationEvaluator - can we get precision/recall by each class label? 【发布时间】:2017-05-12 03:02:14 【问题描述】:

我正在 Spark ML 中使用随机森林进行多类预测。

对于 spark ML 中的这个 MulticlassClassificationEvaluator(),是否可以通过每个类标签获得精度/召回率?

目前,我只看到所有类的精度/召回率结合在一起。

【问题讨论】:

精确/召回也从最新版本的 spark [2.3.0] 中删除 【参考方案1】:

直接使用org.apache.spark.mllib.evaluation.MulticlassMetrics,然后获取可用的指标-

// copied from spark git
val predictionAndLabels =
      dataset.select(col($(predictionCol)), col($(labelCol)).cast(DoubleType)).rdd.map 
        case Row(prediction: Double, label: Double) => (prediction, label)
      
    val metrics = new MulticlassMetrics(predictionAndLabels)

【讨论】:

【参考方案2】:

看看class documentation,这似乎是不可能的,使用内置方法。

虽然不是您要查找的内容,但您可以在 metricName 方法中使用 weightedPrecisionweightedRecall。这至少可以解决班级不平衡问题。

【讨论】:

以上是关于Spark ML - MulticlassClassificationEvaluator - 我们可以通过每个类标签获得精度/召回率吗?的主要内容,如果未能解决你的问题,请参考以下文章

在 Spark 的 map 函数中运行 ML 算法

是否可以访问 spark.ml 管道中的估计器属性?

使用 Spark ML 时出现 VectorUDT 问题

Spark:从管道模型中提取 ML 逻辑回归模型的摘要

Spark|ML|随机森林|从 RandomForestClassificationModel 的 .txt 加载训练模型。 toDebugString

列特征必须是 org.apache.spark.ml.linalg.VectorUDT 类型