Pyspark:保存变压器
Posted
技术标签:
【中文标题】Pyspark:保存变压器【英文标题】:Pyspark: save transformers 【发布时间】:2016-03-21 16:32:23 【问题描述】:我正在使用 Pyspark 的一些转换器,例如 StringIndexer、StandardScaler 等。我首先将它们应用于训练集,然后我想使用相同的转换对象(StringIndexerModel、StandardScalerModel 的相同参数)以便将它们应用于测试集。因此,我正在寻找一种将这些转换函数保存为文件的方法。但是,我找不到任何相关的方法,只能使用诸如 LogisticRegression 之类的 ml 函数。你知道有什么可能的方法吗?谢谢。
【问题讨论】:
您能否澄清一下您使用的是 MLlib 还是 ML? 糟糕,很抱歉。我正在使用 pyspark.ml。 【参考方案1】:我找到了一个简单的解决方案。
将索引器模型保存到文件中(在 HDFS 上)。
writer = indexerModel._call_java("write")
writer.save("indexerModel")
从文件(保存在 HDFS 上)加载索引器模型。
indexer = StringIndexerModel._new_java_obj("org.apache.spark.ml.feature.StringIndexerModel.load", "indexerModel")
indexerModel = StringIndexerModel(indexer)
【讨论】:
【参考方案2】:StringIndexer 和 StandardScaler 的输出都是 RDD,因此您可以将模型直接保存到文件中,或者更可能的是,您可以保留结果以供以后计算。
要保存到 parquet 文件调用(您可能还需要附加架构)sqlContext.createDataFrame(string_indexed_rdd).write.parquet("indexer.parquet")
。然后,您需要编程在需要时从文件中加载此结果。
要坚持调用string_indexed_rdd.persist()
。这会将中间结果保存在内存中以供以后重用。如果内存有限,您也可以传递选项以保存到磁盘。
如果您只想保留模型本身,那么您会遇到 api 中现有的错误/缺失功能 (PR)。如果底层问题解决了,没有提供新方法,则需要手动调用一些底层方法来获取和设置模型参数。查看模型代码,您可以看到模型继承自一系列类,其中一个是Params
。这个类有extractParamMap
,它提取模型中使用的参数。然后,您可以以任何您希望持久化 python dicts 的方式保存它。然后您需要创建一个空模型对象,然后调用copy(saved_params)
将持久化的参数传递给对象。
按照这些思路应该可以工作:
def save_params(model, filename):
d = shelve.open(filename)
try:
return d.update(model.extractParamMap())
finally:
d.close()
def load_params(ModelClass, filename):
d = shelve.open(filename)
try:
return ModelClass().copy(dict(d))
finally:
d.close()
【讨论】:
我对保存转换后的数据不感兴趣,而是模型本身。例如,我想保存不是 RDD 的 StandardScalerModel。 啊,你的问题在这一点上并不清楚——我浏览了源代码并添加了我认为可以实现你想要的东西。 不幸的是,它似乎不起作用,因为参数映射中不包含任何参数。如果我只是检查我的 StringIndexerModel 的实例,它只会返回一个空数组 ([])。 嗯,没有代码或我自己的可调试实例来搞乱很难说什么是错的。你用的是什么版本的spark?我正在阅读 1.6 源代码。这些模型不继承自MLWritable
,因此目前没有可调用的本机方法。在保存之前尝试测试_transfer_params_from_java
,然后打印self._paramMap
——确保那里有东西。然后在设置self._paramMap
后尝试显式调用_transfer_params_to_java
。这些绕过了方法。除此之外,您还需要探索 spark 和 repo 中的实现代码。
我确实查看了源代码并尝试了您刚才描述的内容。显然,已经有一个关于它的拉取请求:github.com/apache/spark/pull/10419以上是关于Pyspark:保存变压器的主要内容,如果未能解决你的问题,请参考以下文章
如何将您安装的变压器保存到 blob 中,以便您的预测管道可以在 AML 服务中使用它?