自定义 scikit-learn 酸洗在网格搜索中不起作用
Posted
技术标签:
【中文标题】自定义 scikit-learn 酸洗在网格搜索中不起作用【英文标题】:Custom scikit-learn pickling doesn't work inside a grid search 【发布时间】:2017-10-26 12:28:23 【问题描述】:我写了一个 scikit-learn 估计器。它有一个参数和一个由fit
设置的model_
属性。
class MyEstimator(BaseEstimator, TransformerMixin):
def __init__(self, param="default"):
self.param = param
self.model_ = None
def fit(self, x, y):
# Sets the value of self.model_
我希望能够腌制MyEstimator
,但我创建的model_
对象不能用pickle 序列化,因为它是keras 模型。按照博文“Pickling Keras Models”的示例,我将以下酸洗处理程序方法添加到我的类中。
class MyEstimator(BaseEstimator, TransformerMixin):
def __getstate__(self):
state = super().__getstate__().copy()
with tempfile.NamedTemporaryFile(suffix=".hdf5", delete=True) as fd:
keras.models.save_model(self.model_, fd.name, overwrite=True)
state["model_"] = fd.read()
return state
def __setstate__(self, state):
super().__setstate__(state)
with tempfile.NamedTemporaryFile(suffix=".hdf5", delete=True) as fd:
fd.write(state["model_"])
fd.flush()
self.__dict__["model_"] = keras.models.load_model(fd.name)
这将不可腌制的model_
成员替换为由 keras 的序列化程序生成的可以腌制的表示。使用此自定义,我可以调用fit
,进行序列化和反序列化,然后取回我的原始模型。一切正常。
e = MyEstimator()
e.fit(x, y)
with open("myfile.pk", mode="wb") as f:
pickle.dump(e, f)
with open("myfile.pk", mode="rb") as f:
pickle.load(f) # Returns a copy of e
但是,当我尝试将MyEstimator
放入pipeline 并腌制GridSearchCV
的结果时,序列化不起作用。
s = GridSearchCV(Pipeline([
# ...
("estimator", MyEstimator())
# ...
]))
s.fit(x, y)
with open("myfile.pk", mode="wb") as f:
pickle.dump(s, f)
在pickle.dump
调用期间,我希望看到MyEstimator.__getstate__
被一个合适的self.model_
对象调用。 (当我在网格搜索之外自行序列化模型时会发生这种情况。)而self.model_
是None
,所以我无法序列化网格搜索生成的best_estimator_
。
看起来网格搜索序列化正在实例化一个新的MyEstimator
对象,而不是使用管道中的那个。这对我来说似乎是错误的。我查看了 scikit-learn 代码,但看不到这是在哪里发生的。
这是 scikit-learn 中的错误,还是我做错了什么?
(注意:keras 确实有一个 wrapper layer 可以将一些 keras 模型转换为 scikit-learn 估计器,但是由于其他原因我不能在这里使用它,而且我不确定它不会有相同的问题。)
【问题讨论】:
看看这个:github.com/fchollet/keras/issues/4274 大多数 scikit 中的模型评估工具,如cross_val_score
、GridSearchCV
等在拟合之前克隆给定的估计器。在 GridSearchCV 中,您可以看到它克隆的 source code here。
指出特定的源代码行会有所帮助。我在调试器中逐步完成这个过程并且迷路了。我不明白为什么在搜索完成后会有任何没有fit
调用的克隆。
【参考方案1】:
搜索对象包含MyEstimator
对象的混合,其中一些对象尚未调用fit
。解决方法是在尝试使用 keras 工具对其进行序列化之前检查 model_
是否为 None
。
class MyEstimator(BaseEstimator, TransformerMixin):
def __getstate__(self):
state = super().__getstate__().copy()
if self.model_ is not None:
with tempfile.NamedTemporaryFile(suffix=".hdf5", delete=True) as fd:
keras.models.save_model(self.model_, fd.name, overwrite=True)
state["model_"] = fd.read()
return state
def __setstate__(self, state):
super().__setstate__(state)
if self.model_ is not None:
with tempfile.NamedTemporaryFile(suffix=".hdf5", delete=True) as fd:
fd.write(state["model_"])
fd.flush()
self.__dict__["model_"] = keras.models.load_model(fd.name)
我不知道为什么网格搜索完成后搜索对象中会有任何未拟合的模型,但是有。
【讨论】:
您是要腌制整个GridSearchCV
对象还是仅腌制 GridSearchCV.best_estimator_
(这基本上是您想要的)
我正在尝试腌制整个搜索对象。我知道GridSearchCV.best_estimator_
是我将在测试时使用的,但我也想比较各种超参数设置的交叉验证分数。以上是关于自定义 scikit-learn 酸洗在网格搜索中不起作用的主要内容,如果未能解决你的问题,请参考以下文章