用于管道的网格搜索参数网格的说明
Posted
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了用于管道的网格搜索参数网格的说明相关的知识,希望对你有一定的参考价值。
此'feature_selection__k': list(range(1, len(feature_importances) + 1))
代码在什么意思:
param_grid = [{
'preparation__num__imputer__strategy': ['mean', 'median', 'most_frequent'],
'feature_selection__k': list(range(1, len(feature_importances) + 1))
}]
grid_search_prep = GridSearchCV(prepare_select_and_predict_pipeline, param_grid, cv=5,
scoring='neg_mean_squared_error', verbose=2)
grid_search_prep.fit(housing, housing_labels)
where
full_pipeline = ColumnTransformer([
("num", num_pipeline, num_attribs),
("cat", OneHotEncoder(), cat_attribs),
])
和
num_pipeline = Pipeline([
('imputer', SimpleImputer(strategy="median")),
('attribs_adder', CombinedAttributesAdder()),
('std_scaler', StandardScaler()),
])
您能解释一下'feature_selection__k': list(range(1, len(feature_importances) + 1))
行中的每一步吗?
如果需要所有代码,则在这里:https://github.com/ageron/handson-ml2/blob/master/02_end_to_end_machine_learning_project.ipynb。我要问的代码部分在笔记本底部。
因为GridSearchCV
在这里不适用于简单的估算器,但在管道中适用:
prepare_select_and_predict_pipeline = Pipeline([
('preparation', full_pipeline),
('feature_selection', TopFeatureSelector(feature_importances, k)),
('svm_reg', SVR(**rnd_search.best_params_))
])
where
full_pipeline = ColumnTransformer([
("num", num_pipeline, num_attribs),
("cat", OneHotEncoder(), cat_attribs),
])
和
num_pipeline = Pipeline([
('imputer', SimpleImputer(strategy="median")),
('attribs_adder', CombinedAttributesAdder()),
('std_scaler', StandardScaler()),
])
param_grid
需要组织成“级别”,以便它知道它将访问组成管道组件的确切参数。
因此,字符串的feature_selection
部分是指prepare_select_and_predict_pipeline
的各个部分,k
是指TopFeatureSelector
的各个参数。级别由双下划线__
分隔,因此要访问k
的参数TopFeatureSelector
,相应的参数定义为feature_selection__k
。
出于相同的原因,为了使param_grid
访问strategy
的SimpleImputer
参数,相应的条目为'preparation__num__imputer__strategy'
,即:
- 第一个
preparation
的prepare_select_and_predict_pipeline
部分 - 第二个
num
的full_pipeline
部分 imputer
的num_pipeline
部分strategy
的参数SimpleImputer
每个条目,如前所述,用双下划线__
分隔。
[list(range(1, len(feature_importances) + 1))
返回一个从1到n的整数列表,其中n是feature_importances
的长度。
例如,如果feature_importances
的长度为5。
list(range(1, len(feature_importances) + 1))
将返回:
[1, 2, 3, 4, 5]
您需要添加“ +1”,因为范围不包括上限。
以上是关于用于管道的网格搜索参数网格的说明的主要内容,如果未能解决你的问题,请参考以下文章
scikit-learn 管道:对变压器参数进行网格搜索以生成数据