用于 PySpark 的酸洗猴子补丁 Keras 模型

Posted

技术标签:

【中文标题】用于 PySpark 的酸洗猴子补丁 Keras 模型【英文标题】:Pickling monkey-patched Keras model for use in PySpark 【发布时间】:2018-04-24 16:54:16 【问题描述】:

我要实现的总体目标是将 Keras 模型发送给每个 spark worker,以便我可以在应用于 DataFrame 列的 UDF 中使用该模型。为此,Keras 模型需要是可腌制的。

似乎很多人通过猴子修补 Model 类成功地腌制了 keras 模型,如下面的链接所示:

http://zachmoshe.com/2017/04/03/pickling-keras-models.html

但是,我还没有看到任何如何与 Spark 一起执行此操作的示例。我的第一次尝试只是在驱动程序中运行 make_keras_picklable() 函数,它允许我在驱动程序中腌制和取消腌制模型,但我无法在 UDF 中腌制模型。

def make_keras_picklable():
    "Source: https://zachmoshe.com/2017/04/03/pickling-keras-models.html"
    ...

make_keras_picklable()

model = Sequential() # etc etc

def score(case):
    ....
    score = model.predict(case)
    ...

def scoreUDF = udf(score, ArrayType(FloatType()))

我得到的错误表明在 UDF 中解开模型没有使用猴子补丁模型类。

AttributeError: 'Sequential' object has no attribute '_built'

看起来另一个用户在this SO post 中遇到了类似的错误,答案是“在每个工作人员上也运行make_keras_picklable()”。没有给出如何执行此操作的示例。

我的问题是:对所有工作人员调用make_keras_picklable() 的适当方式是什么?

我尝试使用broadcast()(见下文)但得到与上述相同的错误。

def make_keras_picklable():
    "Source: https://zachmoshe.com/2017/04/03/pickling-keras-models.html"
    ...

make_keras_picklable()
spark.sparkContext.broadcast(make_keras_picklable())

model = Sequential() # etc etc

def score(case):
    ....
    score = model.predict(case)
    ...

def scoreUDF = udf(score, ArrayType(FloatType()))

【问题讨论】:

您可以尝试使用pandas_udf 并仅广播模型权重。然后在 pandas_udf 中构建模型并加载权重。 【参考方案1】:

Khaled Zaouk 在 Spark 用户邮件列表上通过建议将 make_keras_picklable() 更改为包装类来帮助我。效果很好!

import tempfile

import tensorflow as tf


class KerasModelWrapper:
    """Source: https://zachmoshe.com/2017/04/03/pickling-keras-models.html"""

    def __init__(self, model):
        self.model = model

    def __getstate__(self):
        model_str = ""
        with tempfile.NamedTemporaryFile(suffix=".hdf5", delete=True) as fd:
            tf.keras.models.save_model(self.model, fd.name, overwrite=True)
            model_str = fd.read()
        d = "model_str": model_str
        return d

    def __setstate__(self, state):
        with tempfile.NamedTemporaryFile(suffix=".hdf5", delete=True) as fd:
            fd.write(state["model_str"])
            fd.flush()
            self.model = tf.keras.models.load_model(fd.name)

当然,通过将其实现为 Keras 模型类的子类或 PySpark.ML 转换器/估计器,这可能会变得更加优雅。

【讨论】:

感谢分享。您能否扩展您的实施,我无法让这个解决方案发挥作用。 @MatthewJackson 如果您有更多关于为什么解决方案不适合您的详细信息,我很乐意尝试并提供帮助。据我记得,对于我的用例,该解决方案的工作原理与发布的完全一致。 我不明白这应该如何实现。我创建了 KerasModelWrapper(my_model) 类的一个实例,但是当我保存并加载包装模型并尝试使用它时,我收到错误说没有类方法预测。我正在尝试将包装模型用作 sklearn 管道的一部分。 @Erp12 你能分享一个最小可重现的例子吗?让它工作有一些困难。 @Erp12 我的问题更多的是在脚本上运行它的位置。在你训练之前?之后?加载时也使用它吗?由于我的 dl 架构中有自定义层,我无法运行。我知道我可以通过 load_model 中的 custom_objects 传递它们,但我也有错误。对我有用的唯一方法是将 df 转换为 rdd 并使用 mapPartitions,其中模型加载到每个分区中,因此无需序列化。但我真的很想通过 keras 模型的序列化来完成这项工作。【参考方案2】:

使用与 Erp12 相同的思想,您可以使用该类来包装 keras 模型,动态创建其所有属性,具有装饰器模式的相同精神并扩展 keras 模型,正如 Erp12 建议的那样。

import tempfile
import tensorflow as tf

class PicklableKerasModel(tf.keras.models.Model):

    def __init__(self, model):
        self._model = model

    def __getstate__(self):
        with tempfile.NamedTemporaryFile(suffix='.hdf5', delete=True) as fd:
            tf.keras.models.save_model(self._model, fd.name, overwrite=True)
            model_str = fd.read()
        d = 'model_str': model_str
        return d

    def __setstate__(self, state):
        with tempfile.NamedTemporaryFile(suffix='.hdf5', delete=True) as fd:
            fd.write(state['model_str'])
            fd.flush()
            model = tf.keras.models.load_model(fd.name)
        self._model = model

    def __getattr__(self, name):
        return getattr(self.__dict__['_model'], name)

    def __setattr__(self, name, value):
        if name == '_model':
            self.__dict__['_model'] = value
        else:
            setattr(self.__dict__['_model'], name, value)

    def __delattr__(self, name):
        delattr(self.__dict__['_model'], name)

然后您就可以使用包装您的 keras 模型的模型,例如:

model = Sequential() # etc etc

picklable_model = PicklableKerasModel(model)

【讨论】:

以上是关于用于 PySpark 的酸洗猴子补丁 Keras 模型的主要内容,如果未能解决你的问题,请参考以下文章

activerecord的一个猴子补丁,用于在失去与mysql服务器的连接后重新连接

python boto的Decimal上下文的猴子补丁,允许浮动的不精确和圆形表示。用于在运行时存储DynamoDB中的任何浮动

Python 3:猴子补丁代码不能通过多处理重新导入

猴子补丁的应用,猴子补丁来改变日志。

Python基础复习函数篇

如何绕过 Angular 使用自己的猴子补丁来撤消我的猴子补丁?