如何在 PySpark 中使用 StandardScaler 标准化测试数据集?

Posted

技术标签:

【中文标题】如何在 PySpark 中使用 StandardScaler 标准化测试数据集?【英文标题】:how do I standardize test dataset using StandardScaler in PySpark? 【发布时间】:2021-01-02 07:25:34 【问题描述】:

我有如下训练和测试数据集:

x_train:

inputs
[2,5,10]
[4,6,12]
...

x_test:

inputs
[7,8,14]
[5,5,7]
...

输入列是在将 VectorAssembler 类应用于 3 个单独的列之后包含模型特征的向量。

当我尝试使用下面的 StandardScaler 转换测试数据时,我收到一条错误消息,指出它没有转换方法:

from pyspark.ml.feature import StandardScaler 
scaler = StandardScaler(inputCol="inputs", outputCol="scaled_features")
scaledTrainDF = scaler.fit(x_train).transform(x_train)
scaledTestDF = scaler.transform(x_test)

有人告诉我,我应该只在训练数据上拟合一次标准缩放器,并使用这些参数来转换测试集,所以这样做是不准确的:

scaledTestDF = scaler.fit(x_test).transform(x_test)

那么我该如何处理上面提到的错误呢?

【问题讨论】:

【参考方案1】:

这是使用缩放器的正确语法。您需要在拟合模型上调用变换,而不是在缩放器本身上。

from pyspark.ml.feature import StandardScaler 
scaler = StandardScaler(inputCol="inputs", outputCol="scaled_features")
scaler_model = scaler.fit(x_train)

scaledTrainDF = scaler_model.transform(x_train)
scaledTestDF = scaler_model.transform(x_test)

【讨论】:

以上是关于如何在 PySpark 中使用 StandardScaler 标准化测试数据集?的主要内容,如果未能解决你的问题,请参考以下文章

ECSHOP安装或使用中提示Strict Standards: Non-static method cls_image:

爱创课堂每日一题第二天8/24日 Quirks模式是什么?它和Standards模式有什么区别?

php中出现Strict Standards: Only variables should be passed by reference in的解决方法

Apache Spark:如何在Python 3中使用pyspark

PySpark:如何在列中使用 Or 进行分组

如何在 Pyspark 中使用 Scala 函数?