scikit-learn 中自定义内核 SVM 的交叉验证

Posted

技术标签:

【中文标题】scikit-learn 中自定义内核 SVM 的交叉验证【英文标题】:Cross validation for custom kernel SVM in scikit-learn 【发布时间】:2015-09-29 21:15:16 【问题描述】:

我想通过交叉验证使用 scikit-learn 对自定义内核 SVM 进行网格搜索。更准确地关注this example 我想定义一个像

这样的核函数
def my_kernel(x, y):
"""
We create a custom kernel:
k(x, y) = x * M *y.T          
"""
return np.dot(np.dot(x, M), y.T)

其中 M 是内核的一个参数(如高斯内核中的 gamma)。

我想通过GridSearchCV 提供这个参数 M,类似于

parameters = 'kernel':('my_kernel'), 'C':[1, 10], 'M':[M1,M2]
svr = svm.SVC()
clf = grid_search.GridSearchCV(svr, parameters)

所以我的问题是:如何定义 my_kernel 以便 M 变量将由 GridSearchCV 给出?

【问题讨论】:

【参考方案1】:

您可能需要创建一个包装类。比如:

class MySVC(BaseEstimator,ClassifierMixin):
    def __init__( self, 
              # all the SVC attributes
              M ):
         self.M = M
         # etc...

    def fit( self, X, y ):
         kernel = lambda x,y : np.dot(np.dot(x,M),y.T)
         self.svc_ = SVC( kernel=kernel, # the other parameters )
         return self.svc_.fit( X, y )
    def predict( self, X ):
         return self.svc_.predict( X )
    # et cetera

【讨论】:

以上是关于scikit-learn 中自定义内核 SVM 的交叉验证的主要内容,如果未能解决你的问题,请参考以下文章

如何为 sklearn.svm.SVC 定义自定义内核函数?

Scikit-learn 的带有线性内核 svm 的 GridSearchCV 耗时太长

在 python scikit-learn 中,RBF 内核的性能比 SVM 中的线性差得多

绘制scikit-learn(sklearn)SVM决策边界/表面

使用 scikit-learn python 的线性 SVM 时出现 ValueError

为啥 scikit-learn SVM 分类器交叉验证这么慢?