如何保存 GridSearchCV 对象?

Posted

技术标签:

【中文标题】如何保存 GridSearchCV 对象?【英文标题】:How to save GridSearchCV object? 【发布时间】:2018-12-27 16:33:19 【问题描述】:

最近,我一直致力于在带有 Tensorflow 后端的 Keras 中应用网格搜索交叉验证 (sklearn GridSearchCV) 进行超参数调整。我的模型调整好后 我正在尝试保存 GridSearchCV 对象以供以后使用,但没有成功。

超参数调优如下:

x_train, x_val, y_train, y_val = train_test_split(NN_input, NN_target, train_size = 0.85, random_state = 4)

history = History() 
kfold = 10


regressor = KerasRegressor(build_fn = create_keras_model, epochs = 100, batch_size=1000, verbose=1)

neurons = np.arange(10,101,10) 
hidden_layers = [1,2]
optimizer = ['adam','sgd']
activation = ['relu'] 
dropout = [0.1] 

parameters = dict(neurons = neurons,
                  hidden_layers = hidden_layers,
                  optimizer = optimizer,
                  activation = activation,
                  dropout = dropout)

gs = GridSearchCV(estimator = regressor,
                  param_grid = parameters,
                  scoring='mean_squared_error',
                  n_jobs = 1,
                  cv = kfold,
                  verbose = 3,
                  return_train_score=True))

grid_result = gs.fit(NN_input,
                    NN_target,
                    callbacks=[history],
                    verbose=1,
                    validation_data=(x_val, y_val))

备注:create_keras_model 函数初始化并编译一个 Keras Sequential 模型。

执行交叉验证后,我尝试使用以下代码保存网格搜索对象 (gs):

from sklearn.externals import joblib

joblib.dump(gs, 'GS_obj.pkl')

我得到的错误如下:

TypeError: can't pickle _thread.RLock objects

请告诉我这个错误的可能原因是什么?

谢谢!

P.S.:joblib.dump 方法适用于保存使用的 GridSearchCV 对象 用于训练来自 sklearn 的 MLPRegressors。

【问题讨论】:

如果我的回答解决了您的问题,请告诉我。 【参考方案1】:

使用

import joblib直接

而不是

from sklearn.externals import joblib

保存对象或结果:

joblib.dump(gs, 'model_file_name.pkl')

并使用以下方法加载您的结果:

joblib.load("model_file_name.pkl")

这是一个简单的工作示例:


import joblib

#save your model or results
joblib.dump(gs, 'model_file_name.pkl')

#load your model for further usage
joblib.load("model_file_name.pkl")

【讨论】:

【参考方案2】:

试试这个:

from sklearn.externals import joblib
joblib.dump(gs.best_estimator_, 'filename.pkl')

如果您想将对象转储到一个文件中 - 使用:

joblib.dump(gs.best_estimator_, 'filename.pkl', compress = 1)

简单示例:

from sklearn import svm, datasets
from sklearn.model_selection import GridSearchCV
from sklearn.externals import joblib

iris = datasets.load_iris()
parameters = 'kernel':('linear', 'rbf'), 'C':[1, 10]
svc = svm.SVC()
gs = GridSearchCV(svc, parameters)
gs.fit(iris.data, iris.target)

joblib.dump(gs.best_estimator_, 'filename.pkl')

#['filename.pkl']

编辑 1:

你也可以保存整个对象:

joblib.dump(gs, 'gs_object.pkl')

【讨论】:

感谢您的回复!如果我没记错的话,您提出的方法是仅保存具有最佳调整参数(最佳估计器)的模型。但是,我想做的是保存 GridSearchCV 对象中包含的所有信息,这意味着所有训练模型的性能信息。一种方法是保存 gs.cv_results_ 而不是整个对象,但我只是想知道为什么我不允许将整个对象保存在文件中。 您可以使用joblib.dump(gs, 'gs_object.pkl') 保存整个对象。查看我编辑的答案 正如我的问题中所述,我已经尝试过这种方法来保存整个对象,但它不起作用。我还没弄明白为什么。 @E.Thrampoulidis 我自己也在做这个。问题是 GridSearchCV 旨在通过 n_jobs 参数支持并行性。据我所知,没有简单的方法来腌制支持并行调用的对象(因此出现有关腌制线程的错误)。 Pickle 非常适合字典 (cv_results) 等简单的数据结构,但对于最初从未打算用于序列化的复杂对象(例如 GridSearchCV 类)来说,它不是一个好的选择。 joblib 自 scikit 0.21 起已弃用,并将在 0.23 中删除。现在,它需要通过 pip (pip install joblib) 或 conda (conda install -c anaconda joblib) 作为单独的包安装【参考方案3】:

子类化sklearn.model_selection._search.BaseSearchCV 类。覆盖fit(self, X, y=None, groups=None, **fit_params) 方法,并修改其内部evaluate_candidates(candidate_params) 函数。不要立即从evaluate_candidates(candidate_params) 返回results 字典,而是在此处执行序列化(或在_run_search 方法中执行,具体取决于您的用例)。通过一些额外的修改,这种方法还有一个额外的好处,即允许您按顺序执行网格搜索(请参阅此处源代码中的注释:_search.py)。请注意,evaluate_candidates(candidate_params) 返回的 results 字典与 cv_results 字典相同。这种方法对我有用,但我也尝试为中断的网格搜索执行添加保存和恢复功能。

【讨论】:

嗨,克里斯!您是否能够保存和恢复中断的网格搜索?我想对 BayesSearchCV(来自 Scikit-Optimize 库)做类似的事情,它使用与 GridSearchCV 类似的接口。 @SergeGardien 是的,但这不是一个快速解决方案。您必须修改核心库中的一些方法。最好只维护自己的 cv_results 字典并从中进行序列化和恢复。 明白,谢谢。问题是 BayesSearchCV 是路径相关的,与 GridSearchCV 不同,我认为简单地存储 cv_results 不足以让所有信息恢复过程。无论如何,如果我能找到一些时间,我会看看,否则我会尽量不要处于需要恢复优化过程的情况。 @SergeGardien 我很乐意一有机会就提供更多详细信息。祝你好运!

以上是关于如何保存 GridSearchCV 对象?的主要内容,如果未能解决你的问题,请参考以下文章

如何保存 GridSearchCV 对象?

在 GridSearchCV 中,如何只传递 param_grid 中的默认参数?

什么取代了 scikit 中的 GridSearchCV._grid_scores_?

在 GridSearchCV 中使用精度作为评分时如何指定正标签

Scikit-learn 的 GridSearchCV 中的 Grid_scores_ 是啥意思

为啥 sklearn.grid_search.GridSearchCV 在每次执行时都会返回随机结果?