从 Spark ML LinearSVC 解释 rawPrediction

Posted

技术标签:

【中文标题】从 Spark ML LinearSVC 解释 rawPrediction【英文标题】:Interpreting rawPrediction from Spark ML LinearSVC 【发布时间】:2018-12-22 05:44:10 【问题描述】:

我在二元分类模型中使用 Spark ML 的 LinearSVC。 transform 方法创建两列,predictionrawPrediction。 Spark 的文档没有提供任何解释此特定分类器的 rawPrediction 列的方式。其他分类器已经提出并回答了这个问题,但不是专门针对 LinearSVC。

我的predictions 数据框中的相关列:

+------------------------------------------+ 
|rawPrediction                             | 
+------------------------------------------+ 
|[0.8553257800650063,-0.8553257800650063]  | 
|[0.4230977574196645,-0.4230977574196645]  | 
|[0.49814263303537865,-0.49814263303537865]| 
|[0.9506355050332026,-0.9506355050332026]  | 
|[0.5826887000450813,-0.5826887000450813]  | 
|[1.057222808292026,-1.057222808292026]    | 
|[0.5744214192446275,-0.5744214192446275]  | 
|[0.8738081933835614,-0.8738081933835614]  | 
|[1.418173816502859,-1.418173816502859]    | 
|[1.0854125533426737,-1.0854125533426737]  | 
+------------------------------------------+

显然,这不仅仅是属于每个类别的概率。这是什么?

编辑:由于已请求输入代码,因此这里有一个基于原始数据集中特征子集的模型。使用 Spark 的 LinearSVC 拟合任何数据都会生成此列。

var df = sqlContext
  .read
  .format("csv")
  .option("header", "true")
  .option("inferSchema", "true")
  .load("/FileStore/tables/full_frame_20180716.csv")


var assembler = new VectorAssembler()
  .setInputCols(Array("oy_length", "ah_length", "ey_length", "vay_length", "oh_length", 
                      "longest_word_length", "total_words", "repeated_exact_words",
                      "repeated_bigrams", "repeated_lemmatized_words", 
                      "repeated_lemma_bigrams"))
  .setOutputCol("features")

df = assembler.transform(df)

var Array(train, test) = df.randomSplit(Array(.8,.2), 42)

var supvec = new LinearSVC()
  .setLabelCol("written_before_2004")
  .setMaxIter(10)
  .setRegParam(0.001)

var supvecModel = supvec.fit(train)

var predictions = supvecModel.transform(test)

predictions.select("rawPrediction").show(20, false)

输出:

+----------------------------------------+ 
|rawPrediction | 
+----------------------------------------+ 
|[1.1502868455791242,-1.1502868455791242]| 
|[0.853488887006264,-0.853488887006264] | 
|[0.8064994501574174,-0.8064994501574174]| 
|[0.7919862003563363,-0.7919862003563363]| 
|[0.847418035176922,-0.847418035176922] | 
|[0.9157433788236442,-0.9157433788236442]| 
|[1.6290888181913814,-1.6290888181913814]| 
|[0.9402461917731906,-0.9402461917731906]| 
|[0.9744052798627367,-0.9744052798627367]| 
|[0.787542624053347,-0.787542624053347] | 
|[0.8750602657901001,-0.8750602657901001]| 
|[0.7949414037722276,-0.7949414037722276]| 
|[0.9163545832998052,-0.9163545832998052]| 
|[0.9875454213431247,-0.9875454213431247]| 
|[0.9193015302646135,-0.9193015302646135]| 
|[0.9828623328048487,-0.9828623328048487]| 
|[0.9175976004208621,-0.9175976004208621]| 
|[0.9608750388820302,-0.9608750388820302]| 
|[1.029326217566756,-1.029326217566756] | 
|[1.0190290910146256,-1.0190290910146256]| +----------------------------------------+ 
only showing top 20 rows

【问题讨论】:

发布代码和一些返回此类输出的数据 @VivekKumar 我知道最好有关于主题的问题,我也鼓励有人这样做,但这是一个非常完整的问题,具有清晰准确的问题定义。建议更好的堆栈,例如数据科学或交叉验证,而不是投反对票、标记和持有。 @jonwhithers datascience.stackexchange.com @sconfluentus 当我投反对票和标记时,这是因为问题没有重现输出所需的代码。为什么你认为数据科学是最适合做这件事的地方? @VivekKumar 因为问题不在于代码,而在于模型的输出,该模型存在于第一个化身中。看到代码不会使输出更容易理解,因此更适合 CrossValidated 或 DataScience 堆栈,人们期待建模问题......无需提供数据和代码来生成输出以供线性回归解释如果存在系数表,则为 p 值,并且也无需查看用于描述 spark 的 rawPrediction 格式的数据……前提是您知道该模型的工作原理。 【参考方案1】:

(-margin, margin)

override protected def predictRaw(features: Vector): Vector = 
    val m = margin(features)
    Vectors.dense(-m, m)
  

【讨论】:

【参考方案2】:

正如arpad所说,是margin。

而边距是:

      margin = coefficients * feature + intercept    
                            or
                     y = w * x + b

如果你将边距除以系数的范数,你将得到每个数据点到超平面的距离。

【讨论】:

以上是关于从 Spark ML LinearSVC 解释 rawPrediction的主要内容,如果未能解决你的问题,请参考以下文章

Spark:从管道模型中提取 ML 逻辑回归模型的摘要

Spark|ML|随机森林|从 RandomForestClassificationModel 的 .txt 加载训练模型。 toDebugString

如何从 PySpark 中的 spark.ml 中提取模型超参数?

Spark ml 和 PMML 导出

我通过使用它的 pyspark.ml.regression.LinearRegression 在 spark 中创建一个模型

spark mllib和ml类里面的区别