PySpark ML:LinearSVC 的 OnevsRest 策略
Posted
技术标签:
【中文标题】PySpark ML:LinearSVC 的 OnevsRest 策略【英文标题】:PySpark ML: OnevsRest strategy for LinearSVC 【发布时间】:2018-05-14 20:46:09 【问题描述】:我是 PySpark 的新手。我在 Windows 10 上安装了 Spark 2.3.0 。 我想使用线性 SVM 分类器进行交叉验证训练,但用于具有 3 个类的数据集。所以我正在尝试从 Spark ML 应用 One vs Rest 策略。但我的代码似乎有问题,因为我收到一个错误,表明 LinearSVC 用于二进制分类。
这是我在调试时尝试执行“crossval.fit”行时出现的错误:
pyspark.sql.utils.IllegalArgumentException: u'requirement failed: LinearSVC only supports binary classification. 1 classes detected in LinearSVC_43a48b0b70d59a8cbdb1__labelCol'
这是我的代码: (我正在尝试仅包含 10 个实例的非常小的数据集)
from pyspark import SparkContext
sc = SparkContext('local', 'my app')
from pyspark.ml.linalg import Vectors
from pyspark import SQLContext
sqlContext = SQLContext(sc)
import numpy as np
x_train=np.array([[1,2,3],[5,6,7],[9,10,11],[2,4,5],[2,7,9],[3,7,6],[8,3,6],[5,8,2],[44,11,55],[77,33,22]])
y_train=[1,0,2,1,0,2,1,0,2,1]
#converting numpy array to dataframe
df_list = []
i = 0
for element in x_train: # row
tup = (y_train[i], Vectors.dense(element))
i = i + 1
df_list.append(tup)
Train_sparkframe = sqlContext.createDataFrame(df_list, schema=['label', 'features'])
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.classification import OneVsRest
from pyspark.ml.classification import LinearSVC
LSVC = LinearSVC()
ovr = OneVsRest(classifier=LSVC)
paramGrid = ParamGridBuilder().addGrid(LSVC.maxIter, [10, 100]).addGrid(LSVC.regParam,
[0.001, 0.01, 1.0,10.0]).build()
crossval = CrossValidator(estimator=ovr,
estimatorParamMaps=paramGrid,
evaluator=MulticlassClassificationEvaluator(metricName="f1"),
numFolds=2)
cvModel = crossval.fit(Train_sparkframe)
bestModel = cvModel.bestModel
【问题讨论】:
这应该不是问题了。OneVsRest
类现在也应该支持 LinearSVC
。请再次检查。
【参考方案1】:
正如documentation 所说:
注意现在只支持 LogisticRegression 和 NaiveBayes。
【讨论】:
【参考方案2】:我能够在此 IBM 笔记本中的 Python 3.5/Spark 2.3 托管环境中有效地重现您的代码,而不会出现问题:https://eu-gb.dataplatform.cloud.ibm.com/analytics/notebooks/v2/24bb87d9-d28b-433b-b85a-5a86f4d0b56b/view?access_token=3c7bec3ed89bb518357fcce8005874a66a1d65833e997603141632b5cbb484db
由于云环境为您管理 Spark 上下文,我建议您查看您的 Spark 设置并仔细检查您的列命名。
【讨论】:
以上是关于PySpark ML:LinearSVC 的 OnevsRest 策略的主要内容,如果未能解决你的问题,请参考以下文章
通过 pyspark.ml CrossValidator 调整隐式 pyspark.ml ALS 矩阵分解模型的参数