支持向量机的机器学习网格搜索
Posted
技术标签:
【中文标题】支持向量机的机器学习网格搜索【英文标题】:Machine learning gridsearch for svm 【发布时间】:2016-10-09 21:35:08 【问题描述】:我正在做一个项目,我需要计算 gridsearch 返回的最佳估算器。
parameters = 'gamma':[0.1, 0.5, 1, 10, 100], 'C':[1, 5, 10, 100, 1000]
# TODO: Initialize the classifier
svr = svm.SVC()
# TODO: Make an f1 scoring function using 'make_scorer'
f1_scorer = make_scorer(score_func)
# TODO: Perform grid search on the classifier using the f1_scorer as the scoring method
grid_obj = grid_search.GridSearchCV(svr, parameters, scoring=f1_scorer)
# TODO: Fit the grid search object to the training data and find the optimal parameters
grid_obj = grid_obj.fit(X_train, y_train)
pred = grid_obj.predict(X_test)
def score_func():
f1_score(y_test, pred, pos_label='yes')
# Get the estimator
clf = grid_obj.best_estimator_
我不确定如何使 f1_scorer 函数发挥作用,因为我在创建 gridsearch 对象后进行了预测。创建 obj 后我无法声明 f1_scorer,因为 gridsearch 使用它作为评分方法。请帮助我如何为 gridsearch 创建这个评分函数。
【问题讨论】:
【参考方案1】:clf = svm.SVC()
# TODO: Make an f1 scoring function using 'make_scorer'
f1_scorer = make_scorer(f1_score,pos_label='yes')
# TODO: Perform grid search on the classifier using the f1_scorer as the scoring method
grid_obj = GridSearchCV(clf,parameters,scoring=f1_scorer)
# TODO: Fit the grid search object to the training data and find the optimal parameters
grid_obj = grid_obj.fit(X_train, y_train)
# Get the estimator
clf = grid_obj.best_estimator_
【讨论】:
【参考方案2】:您传递给make_scorer
的记分器函数应该采用y_true
和y_pred
作为参数。有了这些信息,您就拥有了计算分数所需的一切。然后 GridSearchCV 将在内部为每个可能的参数集拟合并调用 score 函数,您无需事先计算 y_pred。
应该是这样的:
def score_func(y_true, y_pred):
"""Calculate f1 score given the predicted and expected labels"""
return f1_score(y_true, y_pred, pos_label='yes')
f1_scorer = make_scorer(score_func)
GridSearchCV(svr, parameters, scoring=f1_scorer)
【讨论】:
谢谢!效果很好。如果我可能会问,gridsearch 如何自己返回预测?它与 make_scorer 函数有关吗? 使用与您相同的方法,使用估算器的.predict()
方法。它在内部将数据拆分为验证集和测试集。然后拟合训练集(它是X_train, y_train
的一个子集)并预测并与它的内部测试集(它也是X_train, y_train
的一个子集)进行比较。所以它永远不会使用你的X_test
。那是为了让您在没有偏见的情况下评估您的最终模型以上是关于支持向量机的机器学习网格搜索的主要内容,如果未能解决你的问题,请参考以下文章