Is there any means to serialize custom Transformer in Spark ML Pipeline
【Title】: Is there any means to serialize custom Transformer in Spark ML Pipeline
【Posted】: 2016-10-27 12:08:18
【Question】: I'm using an ML Pipeline with various custom, UDF-based transformers. What I'm looking for is a way to serialize/deserialize this pipeline.
I serialize the PipelineModel using ObjectOutputStream.write(). However, whenever I try to deserialize the pipeline I get:
java.lang.ClassNotFoundException: org.sparkexample.DateTransformer
where DateTransformer is my custom transformer. Is there some method/interface I should implement to get proper serialization?
I found that there is an MLWritable interface that my class could implement (DateTransformer extends Transformer), but I can't find a useful example of it.
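For reference, the Java-serialization round trip described above looks roughly like this (a sketch of the approach only, not the original code):

import java.io.{FileInputStream, FileOutputStream, ObjectInputStream, ObjectOutputStream}
import org.apache.spark.ml.PipelineModel

def savePipeline(model: PipelineModel, path: String): Unit = {
  val out = new ObjectOutputStream(new FileOutputStream(path))
  try out.writeObject(model) finally out.close()
}

def loadPipeline(path: String): PipelineModel = {
  val in = new ObjectInputStream(new FileInputStream(path))
  // Fails with ClassNotFoundException: org.sparkexample.DateTransformer when the
  // custom transformer's class is not on the classpath of the JVM doing the reading
  try in.readObject().asInstanceOf[PipelineModel] finally in.close()
}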
【Comments】:
【Answer 1】: If you are using Spark 2.x or later, extend your transformer with DefaultParamsWritable, for example:
class ProbabilityMaxer extends Transformer with DefaultParamsWritable
Then add a constructor that takes a String parameter:
def this(_uid: String) = this()
Finally, add a companion object so that reading succeeds:
object ProbabilityMaxer extends DefaultParamsReadable[ProbabilityMaxer]
I have this working on my production server. I will add a GitLab link to the project here once I upload it.
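Putting those pieces together, a minimal Spark 2.x transformer might look like the sketch below. The column names ("p0", "p1", "maxProbability") and the transform logic are made up for illustration; the point is the DefaultParamsWritable / DefaultParamsReadable wiring, with the uid taken in the primary constructor so the class has the String constructor the reader needs:

import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.{col, greatest}
import org.apache.spark.sql.types.{DoubleType, StructType}

class ProbabilityMaxer(override val uid: String)
  extends Transformer with DefaultParamsWritable {

  // Plain constructor for when you build the stage yourself
  def this() = this(Identifiable.randomUID("probabilityMaxer"))

  // Made-up logic: take the max of two hypothetical probability columns
  override def transform(dataset: Dataset[_]): DataFrame =
    dataset.withColumn("maxProbability", greatest(col("p0"), col("p1")))

  override def transformSchema(schema: StructType): StructType =
    schema.add("maxProbability", DoubleType, nullable = false)

  override def copy(extra: ParamMap): ProbabilityMaxer = defaultCopy(extra)
}

// Companion object so the stage can be re-instantiated when the pipeline is loaded
object ProbabilityMaxer extends DefaultParamsReadable[ProbabilityMaxer]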
【Comments】:
【Answer 2】: The short answer is that you can't, at least not easily.
The developers have gone out of their way to make adding new transformers/estimators as hard as possible. Basically everything in org.apache.spark.ml.util.ReadWrite is private (except MLWritable and MLReadable), so there is no way to use any of the utility methods/classes/objects in there. There is also (as I'm sure you've discovered) absolutely no documentation on how to do this, but good code documents itself, right?!
Digging through the code in org.apache.spark.ml.util.ReadWrite and org.apache.spark.ml.feature.HashingTF, it seems that you need to override MLWritable.write and MLReadable.read. DefaultParamsWriter and DefaultParamsReader, which appear to contain the actual save/load implementations, save and load a bunch of metadata (class, timestamp, sparkVersion, uid, paramMap).
Any implementation needs to cover at least that much, and a transformer that doesn't need to learn a model can probably get away with only that. A model that needs to be fitted also has to save its fitted data in its save/write implementation - for example, this is what LocalLDAModel does (https://github.com/apache/spark/blob/v1.6.3/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala#L523), so the learned model is essentially just saved as a parquet file (it seems):
val data = sqlContext.read.parquet(dataPath)
.select("vocabSize", "topicsMatrix", "docConcentration", "topicConcentration",
"gammaShape")
.head()
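The corresponding save side of a fitted model writes that data itself, something along these lines (a rough sketch only, not Spark's actual LDA code; DefaultParamsWriter here is either Spark's private utility or the copied version further down, and the model's fields are hypothetical):

import org.apache.hadoop.fs.Path
import org.apache.spark.ml.param.Params
import org.apache.spark.ml.util.MLWriter
import org.apache.spark.mllib.linalg.Vector

// Writer for a hypothetical fitted model: saves the Params metadata plus the learned state
class MyModelWriter(instance: Params, coefficients: Vector, intercept: Double) extends MLWriter {
  override protected def saveImpl(path: String): Unit = {
    // Standard Params metadata goes to path + "/metadata"
    DefaultParamsWriter.saveMetadata(instance, path, sc)
    // Whatever was learned during fit() goes to path + "/data" as parquet
    val dataPath = new Path(path, "data").toString
    sqlContext.createDataFrame(Seq((coefficients, intercept)))
      .toDF("coefficients", "intercept")
      .repartition(1)
      .write
      .parquet(dataPath)
  }
}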
As a test, I copied everything that seemed necessary from org.apache.spark.ml.util.ReadWrite and tested the following transformer, which doesn't do anything useful.
WARNING: this is almost certainly the wrong way to do it and is very likely to break in the future. I sincerely hope I have misunderstood something and someone will correct me on how to actually create a transformer that can be serialized/deserialized.
This is for Spark 1.6.3 and may already be broken if you are using 2.x:
import org.apache.spark.sql.types._
import org.apache.spark.ml.param._
import org.apache.hadoop.fs.Path
import org.apache.spark.SparkContext
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.util.{Identifiable, MLReadable, MLReader, MLWritable, MLWriter}
import org.apache.spark.sql.{SQLContext, DataFrame}
import org.apache.spark.sql.functions.{col, udf}  // needed for the UDF in transform below
import org.apache.spark.mllib.linalg._
import org.json4s._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._
object CustomTransform extends DefaultParamsReadable[CustomTransform] {
  /* Companion object for deserialisation */
  override def load(path: String): CustomTransform = super.load(path)
}

class CustomTransform(override val uid: String)
  extends Transformer with DefaultParamsWritable {

  def this() = this(Identifiable.randomUID("customThing"))

  def setInputCol(value: String): this.type = set(inputCol, value)
  def setOutputCol(value: String): this.type = set(outputCol, value)
  def getOutputCol(): String = getOrDefault(outputCol)

  val inputCol = new Param[String](this, "inputCol", "input column")
  val outputCol = new Param[String](this, "outputCol", "output column")

  override def transform(dataset: DataFrame): DataFrame = {
    val sqlContext = SQLContext.getOrCreate(SparkContext.getOrCreate())
    import sqlContext.implicits._

    val outCol = extractParamMap.getOrElse(outputCol, "output")
    val inCol = extractParamMap.getOrElse(inputCol, "input")
    val transformUDF = udf { vector: SparseVector =>
      vector.values.map(_ * 10)
      // WHATEVER YOUR TRANSFORMER NEEDS TO DO GOES HERE
    }
    dataset.withColumn(outCol, transformUDF(col(inCol)))
  }

  override def copy(extra: ParamMap): Transformer = defaultCopy(extra)

  override def transformSchema(schema: StructType): StructType = {
    val outputFields = schema.fields :+
      StructField(extractParamMap.getOrElse(outputCol, "filtered"), new VectorUDT, nullable = false)
    StructType(outputFields)
  }
}
Then we need all of the utilities from org.apache.spark.ml.util.ReadWrite (https://github.com/apache/spark/blob/v1.6.3/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala):
trait DefaultParamsWritable extends MLWritable { self: Params =>
  override def write: MLWriter = new DefaultParamsWriter(this)
}

trait DefaultParamsReadable[T] extends MLReadable[T] {
  override def read: MLReader[T] = new DefaultParamsReader
}

class DefaultParamsWriter(instance: Params) extends MLWriter {
  override protected def saveImpl(path: String): Unit = {
    DefaultParamsWriter.saveMetadata(instance, path, sc)
  }
}

object DefaultParamsWriter {
  /**
   * Saves metadata + Params to: path + "/metadata"
   *  - class
   *  - timestamp
   *  - sparkVersion
   *  - uid
   *  - paramMap
   *  - (optionally, extra metadata)
   * @param extraMetadata Extra metadata to be saved at same level as uid, paramMap, etc.
   * @param paramMap If given, this is saved in the "paramMap" field.
   *                 Otherwise, all [[org.apache.spark.ml.param.Param]]s are encoded using
   *                 [[org.apache.spark.ml.param.Param.jsonEncode()]].
   */
  def saveMetadata(
      instance: Params,
      path: String,
      sc: SparkContext,
      extraMetadata: Option[JObject] = None,
      paramMap: Option[JValue] = None): Unit = {
    val uid = instance.uid
    val cls = instance.getClass.getName
    val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]]
    val jsonParams = paramMap.getOrElse(render(params.map { case ParamPair(p, v) =>
      p.name -> parse(p.jsonEncode(v))
    }.toList))
    val basicMetadata = ("class" -> cls) ~
      ("timestamp" -> System.currentTimeMillis()) ~
      ("sparkVersion" -> sc.version) ~
      ("uid" -> uid) ~
      ("paramMap" -> jsonParams)
    val metadata = extraMetadata match {
      case Some(jObject) =>
        basicMetadata ~ jObject
      case None =>
        basicMetadata
    }
    val metadataPath = new Path(path, "metadata").toString
    val metadataJson = compact(render(metadata))
    sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath)
  }
}

class DefaultParamsReader[T] extends MLReader[T] {
  override def load(path: String): T = {
    val metadata = DefaultParamsReader.loadMetadata(path, sc)
    val cls = Class.forName(metadata.className, true,
      Option(Thread.currentThread().getContextClassLoader).getOrElse(getClass.getClassLoader))
    val instance =
      cls.getConstructor(classOf[String]).newInstance(metadata.uid).asInstanceOf[Params]
    DefaultParamsReader.getAndSetParams(instance, metadata)
    instance.asInstanceOf[T]
  }
}
object DefaultParamsReader {
  /**
   * All info from metadata file.
   *
   * @param params paramMap, as a [[JValue]]
   * @param metadata All metadata, including the other fields
   * @param metadataJson Full metadata file String (for debugging)
   */
  case class Metadata(
      className: String,
      uid: String,
      timestamp: Long,
      sparkVersion: String,
      params: JValue,
      metadata: JValue,
      metadataJson: String)

  /**
   * Load metadata from file.
   *
   * @param expectedClassName If non empty, this is checked against the loaded metadata.
   * @throws IllegalArgumentException if expectedClassName is specified and does not match metadata
   */
  def loadMetadata(path: String, sc: SparkContext, expectedClassName: String = ""): Metadata = {
    val metadataPath = new Path(path, "metadata").toString
    val metadataStr = sc.textFile(metadataPath, 1).first()
    val metadata = parse(metadataStr)

    implicit val format = DefaultFormats
    val className = (metadata \ "class").extract[String]
    val uid = (metadata \ "uid").extract[String]
    val timestamp = (metadata \ "timestamp").extract[Long]
    val sparkVersion = (metadata \ "sparkVersion").extract[String]
    val params = metadata \ "paramMap"
    if (expectedClassName.nonEmpty) {
      require(className == expectedClassName, s"Error loading metadata: Expected class name" +
        s" $expectedClassName but found class name $className")
    }

    Metadata(className, uid, timestamp, sparkVersion, params, metadata, metadataStr)
  }

  /**
   * Extract Params from metadata, and set them in the instance.
   * This works if all Params implement [[org.apache.spark.ml.param.Param.jsonDecode()]].
   */
  def getAndSetParams(instance: Params, metadata: Metadata): Unit = {
    implicit val format = DefaultFormats
    metadata.params match {
      case JObject(pairs) =>
        pairs.foreach { case (paramName, jsonValue) =>
          val param = instance.getParam(paramName)
          val value = param.jsonDecode(compact(render(jsonValue)))
          instance.set(param, value)
        }
      case _ =>
        throw new IllegalArgumentException(
          s"Cannot recognize JSON metadata: ${metadata.metadataJson}.")
    }
  }

  /**
   * Load a [[Params]] instance from the given path, and return it.
   * This assumes the instance implements [[MLReadable]].
   */
  def loadParamsInstance[T](path: String, sc: SparkContext): T = {
    val metadata = DefaultParamsReader.loadMetadata(path, sc)
    val cls = Class.forName(metadata.className, true,
      Option(Thread.currentThread().getContextClassLoader).getOrElse(getClass.getClassLoader))
    cls.getMethod("read").invoke(null).asInstanceOf[MLReader[T]].load(path)
  }
}
With that in place, you can use CustomTransform in a Pipeline and save/load the pipeline. I tested it fairly quickly in the spark shell and it seems to work, but it is certainly not pretty.
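For completeness, the save/load round trip in the spark shell might look roughly like the sketch below (the output directory and column names are made up):

import org.apache.spark.ml.{Pipeline, PipelineStage}

val custom = new CustomTransform()
  .setInputCol("features")
  .setOutputCol("boostedFeatures")

val pipeline = new Pipeline().setStages(Array[PipelineStage](custom))

// Writes metadata for the pipeline and each stage under the given directory
pipeline.write.overwrite().save("/tmp/custom-pipeline")

// Loading resolves CustomTransform through its companion object's read
val reloaded = Pipeline.load("/tmp/custom-pipeline")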
【Comments】: