带有 RandomForest 的 GridsearchCV

Posted

技术标签:

【中文标题】带有 RandomForest 的 GridsearchCV【英文标题】:GridsearchCV with RandomForest 【发布时间】:2018-01-06 07:02:46 【问题描述】:

所以我正在用 RandomForest 和 GridsearchCV 做一些参数的事情。这是我的代码。

#Import 'GridSearchCV' and 'make_scorer'
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import make_scorer

Create the parameters list you wish to tune
parameters = 'n_estimators':[5,10,15]

#Initialize the classifier
clf = GridSearchCV(RandomForestClassifier(), parameters)

#Make an f1 scoring function using 'make_scorer' 
f1_scorer = make_scorer(f1_scorer)

#Perform grid search on the classifier using the f1_scorer as the scoring method
grid_obj = GridSearchCV(clf, param_grid=parameters, scoring=f1_scorer,cv=5)

print(clf.get_params().keys())

#Fit the grid search object to the training data and find the optimal parameters
grid_obj = grid_obj.fit(X_train_100,y_train_100)

所以问题是以下错误:“ValueError: Invalid parameter max_features for estimator GridSearchCV。使用estimator.get_params().keys()检查可用参数列表。”

我遵循了错误给出的建议, print(clf.get_params().keys()) 的输出如下。但是,即使我将这些标题复制并粘贴到我的参数字典中,我仍然会收到错误消息。我一直在寻找堆栈溢出,大多数人都在使用与我非常相似的参数字典。任何人都知道如何解决这个问题?再次感谢!

dict_keys(['pre_dispatch', 'cv', 'estimator__max_features', 'param_grid', 'refit', 'estimator__min_impurity_split', 'n_jobs', 'estimator__random_state', 'error_score', 'verbose', 'estimator__min_samples_split', 'estimator__n_jobs'、'fit_params'、'estimator__min_weight_fraction_leaf'、'scoring'、'estimator__warm_start'、'estimator__criterion'、'estimator__verbose'、'estimator__bootstrap'、'estimator__class_weight'、'estimator__oob_depth'、'iid'、'estimator'、'estimator__max ', 'estimator__max_leaf_nodes', 'estimator__min_samples_leaf', 'estimator__n_estimators', 'return_train_score'])

【问题讨论】:

【参考方案1】:

我认为问题在于这两行:

clf = GridSearchCV(RandomForestClassifier(), parameters)
grid_obj = GridSearchCV(clf, param_grid=parameters, scoring=f1_scorer,cv=5)

这实际上是在创建一个具有如下结构的对象:

grid_obj = GridSearchCV(GridSearchCV(RandomForestClassifier()))

这可能比你想要的多一个GridSearchCV

【讨论】:

以上是关于带有 RandomForest 的 GridsearchCV的主要内容,如果未能解决你的问题,请参考以下文章

使用并行训练带有插入符号的随机森林

R中RandomForest包中的RandomForest函数中的参数'classwt'代表啥?

用于 R 中回归的 RandomForest

获取 RandomForest 中单个树的重要性

时间序列回归 - RandomForest

R语言随机森林回归(randomforest)模型构建