PySpark 中的分层交叉验证

Posted

技术标签:

【中文标题】PySpark 中的分层交叉验证【英文标题】:Stratified cross-validation in PySpark 【发布时间】:2020-08-31 20:36:02 【问题描述】:

我在 python 中使用 Apache Spark API,PySpark(--version 3.0.0),并且由于我的数据高度不平衡,我希望以分层的方式对我的标记数据执行交叉验证!我目前正在使用以下模块。

from pyspark.ml.tuning import CrossValidator

在 scikit-learn 中,这可以通过定义 StratifiedKFold 并将其放在任何网格搜索函数的 cv 参数中来实现。这确保了对给定估计器进行训练的每 K 折,都以具有比例代表性的方式包含标记数据。

PySpark 中是否有类似的功能?

我很想向 Spark 团队提出这个问题或作为改进建议,但他们的 GitHub page 不允许错误报告或改进建议,他们的 official page 需要随附的源代码建议,这略高于我的技能! 咆哮>

【问题讨论】:

【参考方案1】:

我认为 Spark ML 目前不支持分层交叉验证。

不过,您可以查看spark-stratifier。这是一个 Spark ML 分层交叉验证器组件,几年前由 HackerRank 开源 [1]。你可以看看那个。

运行pip install spark-stratifier即可安装。

一些示例代码可以帮助您:

from pyspark.ml import Pipeline
from pyspark.ml.linalg import Vectors
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.tuning import ParamGridBuilder
from pyspark.sql import SparkSession

from spark_stratifier import StratifiedCrossValidator

spark = SparkSession.builder.master("local").getOrCreate()

data = spark.createDataFrame([
    (1.0, Vectors.dense([0.0, 1.1, 0.1])),
    (0.0, Vectors.dense([2.0, 1.0, -1.0])),
    (0.0, Vectors.dense([2.0, 1.3, 1.0])),
    (1.0, Vectors.dense([0.0, 1.2, -0.5]))], ["label", "features"])

lr = LogisticRegression(maxIter=10)

pipeline = Pipeline(stages=[lr])

paramGrid = ParamGridBuilder() \
    .addGrid(lr.regParam, [0.1, 0.01]) \
    .build()

scv = StratifiedCrossValidator(
        estimator=pipeline,
        estimatorParamMaps=paramGrid,
        evaluator=BinaryClassificationEvaluator(),
        numFolds=2
      )

model = scv.fit(data)

[1]:https://github.com/interviewstreet/spark-stratifier

【讨论】:

由于我已经在 PySpark 3.0.0 中编写了代码并安装了所有相关的依赖项,spark-stratifier 似乎并不喜欢它,并引发了与Py4JError 相关的冗长错误。当我重新安装 PySpark 3.0.0 时它也不起作用

以上是关于PySpark 中的分层交叉验证的主要内容,如果未能解决你的问题,请参考以下文章

Scikit Learn 分层交叉验证中的差异

scikit-learn中的随机分层k折交叉验证?

R中的分层k倍交叉验证

基于 python 中的多个特征的训练测试拆分的分层交叉验证或抽样

Pyspark 线性回归梯度下降交叉验证

Scikit-Learn 中的分层标记 K 折交叉验证