在 scikit learn 中结合网格搜索和交叉验证
Posted
技术标签:
【中文标题】在 scikit learn 中结合网格搜索和交叉验证【英文标题】:Combining Grid search and cross validation in scikit learn 【发布时间】:2013-01-29 17:56:11 【问题描述】:为了改进支持向量机的结果,我必须使用网格搜索来搜索更好的参数和交叉验证。 我不确定如何在 scikit-learn 中组合它们。 网格搜索搜索最佳参数(http://scikit-learn.org/stable/modules/grid_search.html)和交叉验证避免过拟合(http://scikit-learn.org/dev/modules/cross_validation.html)
#GRID SEARCH
from sklearn import grid_search
parameters = 'kernel':('linear', 'rbf'), 'C':[1, 10]
svr = svm.SVC()
clf = grid_search.GridSearchCV(svr, parameters)
#print(clf.fit(X, Y))
#CROSS VALIDATION
from sklearn import cross_validation
X_train, X_test, y_train, y_test = cross_validation.train_test_split(X, Y, test_size=0.4, random_state=0)
clf = svm.SVC(kernel='linear', C=1).fit(X_train, y_train)
print("crossvalidation")
print(clf.score(X_test, y_test))
clf = svm.SVC(kernel='linear', C=1)
scores = cross_validation.cross_val_score(clf, X, Y, cv=3)
print(scores )
结果:
GridSearchCV(cv=None,
estimator=SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0, degree=3, gamma=0.0,
kernel=rbf, probability=False, shrinking=True, tol=0.001, verbose=False),
estimator__C=1.0, estimator__cache_size=200,
estimator__class_weight=None, estimator__coef0=0.0,
estimator__degree=3, estimator__gamma=0.0, estimator__kernel=rbf,
estimator__probability=False, estimator__shrinking=True,
estimator__tol=0.001, estimator__verbose=False, fit_params=,
iid=True, loss_func=None, n_jobs=1,
param_grid='kernel': ('linear', 'rbf'), 'C': [1, 10],
pre_dispatch=2*n_jobs, refit=True, score_func=None, verbose=0)
crossvalidation
0.0
[ 0.11111111 0.11111111 0. ]
【问题讨论】:
【参考方案1】:您应该先进行开发/评估拆分,在开发部分运行网格搜索,最后在评估部分测量唯一的最终分数:
the documentation中有an example。
【讨论】:
我尝试使用我的数据运行,但出现此错误:clf = GridSearchCV(SVC(C=1), tune_parameters,scoring=score) TypeError: __init__() got an unexpected keyword argument 'scoring ',我也尝试运行原始示例并且存在相同的错误,但这怎么可能呢?给它打分是一个函数参数! 检查文档的版本号并选择与您安装的版本相匹配的版本。每个版本的 URL 都不同:scikit-learn.org/dev/modules/grid_search.html 是开发分支。 scikit-learn.org/stable/modules/grid_search.html 是最后一个发布版本(撰写本文时为 0.13),scikit-learn.org/0.13/modules/grid_search.html 是 0.13 版本的固定 URL。 我修复了指向文档稳定版本的答案。以上是关于在 scikit learn 中结合网格搜索和交叉验证的主要内容,如果未能解决你的问题,请参考以下文章