使用来自 gridsearchcv 的最佳参数
Posted
技术标签:
【中文标题】使用来自 gridsearchcv 的最佳参数【英文标题】:using best params from gridsearchcv 【发布时间】:2017-05-19 10:27:37 【问题描述】:我不知道在这里问这个问题是否正确,但无论如何我都会问。如果不允许,请告诉我。
我使用GridSearchCV
调整参数以找到最佳精度。这就是我所做的:
from sklearn.grid_search import GridSearchCV
parameters = 'min_samples_split':np.arange(2, 80), 'max_depth': np.arange(2,10), 'criterion':['gini', 'entropy']
clfr = DecisionTreeClassifier()
grid = GridSearchCV(clfr, parameters,scoring='accuracy', cv=8)
grid.fit(X_train,y_train)
print('The parameters combination that would give best accuracy is : ')
print(grid.best_params_)
print('The best accuracy achieved after parameter tuning via grid search is : ', grid.best_score_)
这给了我以下结果:
The parameters combination that would give best accuracy is :
'max_depth': 5, 'criterion': 'entropy', 'min_samples_split': 2
The best accuracy achieved after parameter tuning via grid search is : 0.8147086914995224
现在,我想在调用可视化决策树的函数时使用这些参数
函数看起来像这样
def visualize_decision_tree(decision_tree, feature, target):
dot_data = export_graphviz(decision_tree, out_file=None,
feature_names=feature,
class_names=target,
filled=True, rounded=True,
special_characters=True)
graph = pydotplus.graph_from_dot_data(dot_data)
return Image(graph.create_png())
现在我正在尝试使用 GridSearchCV 提供的最佳参数,以如下方式调用函数
dtBestScore = DecisionTreeClassifier(parameters = grid.best_params_)
dtBestScore = dtBestScore.fit(X=dfWithTrainFeatures, y= dfWithTestFeature)
visualize_decision_tree(dtBestScore, list(dfCopy.columns.delete(0).values), 'survived')
我在第一行代码中遇到错误
TypeError: __init__() got an unexpected keyword argument 'parameters'
有什么方法可以让我设法使用网格搜索提供的最佳参数并自动使用它?而不是查看结果并手动设置每个参数的值?
【问题讨论】:
python kwargs 不像DecisionTreeClassifier(**grid.best_params)
那样工作吗?有关 kwargs 的更多信息,请参阅pythontips.com/2013/08/04/args-and-kwargs-in-python-explained。
效果惊人。您可以将其写为答案,我可以接受。我对这件事很陌生,不太了解,这对我有很大帮助
添加为答案。谢谢。
【参考方案1】:
尝试 python kwargs:
DecisionTreeClassifier(**grid.best_params)
有关 kwargs 的更多信息,请参阅http://pythontips.com/2013/08/04/args-and-kwargs-in-python-explained。
【讨论】:
如果您优化了管道,最好的方法是什么?带有“pipelinestep__”的键前缀似乎会损害参数的映射? ?) param_dict = x.replace("pipelinestep__", ""): v for x, v in param_dict.items()以上是关于使用来自 gridsearchcv 的最佳参数的主要内容,如果未能解决你的问题,请参考以下文章
GridSearchCV 的替代方法,用于查找 SVM 模型的参数
使用最佳参数构建模型时 GridsearchCV 最佳分数下降
如何将最佳参数(使用 GridSearchCV)从管道传递到另一个管道
ValueError 在 Scikit 中找到最佳超参数时使用 GridSearchCV 学习 LogisticRegression