如果不使用spark-ml中的管道,交叉验证会更快吗?

Posted

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了如果不使用spark-ml中的管道,交叉验证会更快吗?相关的知识,希望对你有一定的参考价值。

假设我的功能工程中有很多步骤:我的管道中会有很多变换器。我想知道Spark在管道的交叉验证过程中如何处理这些变换器:它们是针对每个折叠执行的吗?在交叉验证模型之前应用变压器会更快吗?

哪个工作流程最快(或者有更好的解决方案)?:

1. Cross validator on pipeline

transformer1 = ...
transformer2 = ...
transformer3 = ...
lr = LogisticRegression(...)
pipeline = Pipeline(stages=[transformer1, transformer2, transformer3, lr])
crossval = CrossValidator(estimator=pipeline, numFolds=10, ...)

cvModel = crossval.fit(training)
prediction = cvModel.transform(test)

2. Cross validator after pipeline

transformer1 = ...
transformer2 = ...
transformer3 = ...
pipeline = Pipeline(stages=[transformer1, transformer2, transformer3])
training_trans = pipeline.fit(training).transform(training)

lr = LogisticRegression(...)
crossval = CrossValidator(estimator=lr, numFolds=10, ...)

cvModel = crossval.fit(training_trans)
prediction = cvModel.transform(test)

最后,我对使用缓存有同样的问题:在2.我可以在进行交叉验证之前缓存training_trans。在1.我可以在LogisticRegression之前在管道中使用Cacher变换器。 (参见Caching intermediate results in Spark ML pipeline for the Cacher)

答案

我已经做了实验,但我仍然感兴趣,如果有人能给出更详细的答案。

%%time
pipeline1 = Pipeline(stages=stringIndexers+oneHotEncoders+[vectorAssembler])
train2 = pipeline1.fit(train).transform(train)
crossval = CrossValidator(estimator=logisticRegression, ...)
crossval.fit(train2)

CPU时间:用户508毫秒,系统:136毫秒,总计:644毫秒/壁挂时间:2分钟2秒

%%time
pipeline1 = Pipeline(stages=stringIndexers+oneHotEncoders+[vectorAssembler])
train2 = pipeline1.fit(train).transform(train)
train2.cache().count()
crossval = CrossValidator(estimator=logisticRegression, ...)
crossval.fit(train2)

CPU时间:用户560毫秒,系统:104毫秒,总计:664毫秒/挂起时间:1分钟25秒

%%time
pipeline2 = Pipeline(stages=stringIndexers+oneHotEncoders+[vectorAssembler, logisticRegression])
crossval = CrossValidator(estimator=pipeline2, ...)
crossval.fit(train)

CPU时间:用户2.06秒,系统:504毫秒,总计:2.56秒/挂机时间:3分钟

以上是关于如果不使用spark-ml中的管道,交叉验证会更快吗?的主要内容,如果未能解决你的问题,请参考以下文章

如何将交叉验证目标输入管道中的自定义转换器

如果我在 python 管道中有自定义的集成模型,如何进行交叉验证和网格搜索

交叉验证管道的分类报告

spark-ml 规范化器丢失元数据

如何使用 Sklearn 管道进行参数调整/交叉验证?

KaggleIntermediate Machine Learning(管道+交叉验证)