Sklearn gridsearchCV 对象在 pickle 转储/加载后更改
Posted
技术标签:
【中文标题】Sklearn gridsearchCV 对象在 pickle 转储/加载后更改【英文标题】:Sklearn gridsearchCV object changed after pickle dump/load 【发布时间】:2017-08-16 11:46:00 【问题描述】:我有一个我创建的 gridsearchCV 对象
grid_search = GridSearchCV(pred_home_pipeline, param_grid)
我想保存整个网格搜索对象,以便稍后探索模型调整结果。我不想只保存the best_estimator_
。但是在转储和重新加载之后,重新加载的和原始的 grid_search 对象在某些方面有所不同,我无法追踪。
# save to disk
with open(filepath, 'wb') as handle:
pickle.dump(grid_search, handle, protocol=pickle.HIGHEST_PROTOCOL)
# reload
with open(filepath, 'rb') as handle:
grid_reloaded = pickle.load(handle)
# test object is unchanged after dump/reload
print(grid_search == grid_reloaded)
错误
很奇怪。查看print(grid_search)
和print(grid_reloaded)
的输出,它们看起来确实是一样的。
他们为我完全从网格搜索过程中提取的数据创建了完全相同的一组 525 个预测值:
grid_search_preds = grid_search.predict(X_test)
grid_reloaded_preds= grid_reloaded.predict(X_test)
(grid_search_preds == grid_reloaded_preds).all()
是的
...尽管best_estimator_
属性在技术上并不相同:
grid_search.best_estimator_ == grid_reloaded.best_estimator_
错误
...虽然 best_estimate_ 属性在比较 print(grid_search.best_estimatmator_)
和 print(grid_reloaded.best_estimator_)
时看起来也一样
这里发生了什么?保存 gridsearchcv 对象以供以后检查是否安全?
【问题讨论】:
我猜网格搜索对象根本没有定义“基于功能”的平等概念。只有当它们是完全相同的对象时,它们才可能被认为是相等的。尝试创建两个相同的 GridSearch 对象(通过两次运行相同的创建代码)并查看它们是否相等;我猜他们不会。这可能意味着您确实可以像往常一样使用腌制对象,但它不会“看起来”等于其他等效对象(就从您的==
测试中获得真实值而言)。
【参考方案1】:
那是因为比较返回的是对象是否是同一个对象。
要了解原因,请遵循对象层次结构,您会发现没有覆盖 __eq__
函数(或 __cmp__
):
因此,“==”比较回退到对象内存位置比较,对于该比较,您重新加载的实例和当前实例当然不能相等。这是比较以查看它们是否是同一个对象。
查看更多here。
【讨论】:
非常感谢。您在此处链接到的堆栈溢出答案中链接的博客文章真的很吸引人。 @MaxPower 如果回答了您的问题,请选中复选标记【参考方案2】:这里是来自 sklearn 的 github 的 sklearn contributor GaelVaroquaux's answer,关于为什么这里没有实现 __eq__
方法,以及测试两个 sklearn 对象相等性的解决方案:
不,我宁愿不添加 eq。这些东西很难 做对了,人们不应该期望一个库来实现 eq on 复杂的对象。
您可以做的一件事是使用 joblib.hash 来计算 MD5 哈希 的对象,并使用它进行比较。
【讨论】:
以上是关于Sklearn gridsearchCV 对象在 pickle 转储/加载后更改的主要内容,如果未能解决你的问题,请参考以下文章
GridSearchCV 的 sklearn 中的自定义“k 精度”评分对象
你能从 sklearn 网格搜索 (GridSearchCV) 中获得所有估计器吗?