sklearn GridSearchCV 与管道
Posted
技术标签:
【中文标题】sklearn GridSearchCV 与管道【英文标题】:sklearn GridSearchCV with Pipeline 【发布时间】:2014-01-29 18:26:09 【问题描述】:我是 sklearn
的 Pipeline
和 GridSearchCV
功能的新手。我正在尝试构建一个管道,该管道首先对我的训练数据执行 RandomizedPCA,然后拟合一个岭回归模型。这是我的代码:
pca = RandomizedPCA(1000, whiten=True)
rgn = Ridge()
pca_ridge = Pipeline([('pca', pca),
('ridge', rgn)])
parameters = 'ridge__alpha': 10 ** np.linspace(-5, -2, 3)
grid_search = GridSearchCV(pca_ridge, parameters, cv=2, n_jobs=1, scoring='mean_squared_error')
grid_search.fit(train_x, train_y[:, 1:])
我知道RidgeCV
函数,但我想试试 Pipeline 和 GridSearch CV。
我希望网格搜索 CV 报告 RMSE 错误,但这似乎在 sklearn 中不支持,所以我正在使用 MSE。但是,它所报告的分数是负数:
In [41]: grid_search.grid_scores_
Out[41]:
[mean: -0.02665, std: 0.00007, params: 'ridge__alpha': 1.0000000000000001e-05,
mean: -0.02658, std: 0.00009, params: 'ridge__alpha': 0.031622776601683791,
mean: -0.02626, std: 0.00008, params: 'ridge__alpha': 100.0]
显然这对于均方误差是不可能的 - 我在这里做错了什么?
【问题讨论】:
【参考方案1】:这些分数是负 MSE 分数,即否定它们,你得到 MSE。问题是GridSearchCV
,按照惯例,总是试图最大化它的分数,所以像 MSE 这样的损失函数必须被否定。
【讨论】:
你能根据你的测试指出任何关于这个或它的文件吗? github.com/scikit-learn/scikit-learn/issues/2439(我个人认为应该是负数而不是“否定”) 我现在有点困惑。我是否必须在 model.compile() 中使用 'neg_mean_squared_error' 来表示“损失”和度量”或“mean_squared_error”?【参考方案2】:创建GridSearchCV
的另一种方法是使用make_scorer
并将greater_is_better
标志转换为False
因此,如果 clf 是您的分类器,而参数是您的超参数列表,您可以像这样使用 make_scorer
:
from sklearn.metrics import make_scorer
#define your own mse and set greater_is_better=False
mse = make_scorer(mean_squared_error,greater_is_better=False)
现在,如下所示,您可以调用 GridSearch 并传递您定义的 mse
grid_obj = GridSearchCV(clf, parameters, cv=5,scoring=mse,n_jobs = -1, verbose=True)
【讨论】:
【参考方案3】:假设,我已将从 GridSearchCV 获得的负 MSE 和负 MAE 的结果分别存储在名为 model_nmse 和 model_nmae 的列表中。
所以我只需将它与 (-1) 相乘,即可获得所需的 MSE 和 MAE 分数。
model_mse = list(np.multiply(model_nmse , -1))
model_mae = list(np.multiply(model_nmae , -1))
【讨论】:
【参考方案4】:如果您想将 RMSE 作为指标,您可以编写自己的可调用/函数,该函数将采用 Y_pred 和 Y_org 并计算 RMSE。
ref
【讨论】:
【参考方案5】:你可以在文档中看到评分
【讨论】:
问题询问为什么 RMSE 值变为负值;这似乎不是问题的答案。 @Gust 有一个 'neg_root_mean_squared_error',我认为得到 RMSE 会很容易吗? @JeremyCaney 感谢您的建议,这里是 scikit learn 评分文档的链接scikit-learn.org/stable/modules/…以上是关于sklearn GridSearchCV 与管道的主要内容,如果未能解决你的问题,请参考以下文章
XGBoost 与 GridSearchCV、缩放、PCA 和 sklearn 管道中的 Early-Stopping
SKLEARN // 将 GridsearchCV 与列变换和管道相结合
sklearn - 如何从传递给 GridSearchCV 的管道中检索 PCA 组件和解释方差
如何实现 sklearn 的 Estimator 接口以在 GridSearchCV 管道中使用?
管道中的自定义 sklearn 转换器为 cross_validate 抛出 IndexError 但在使用 GridSearchCV 时不会