Plotting a validation curve with sklearn: how different values of one parameter score on the training and test sets

Posted by wzd321


from sklearn.svm import SVC
from sklearn.datasets import make_classification
import numpy as np

# Synthetic binary classification data (100 samples, 20 features by default)
X, y = make_classification()


def plot_validation_curve(estimator,X,y,param_name="gamma",
                          param_range=np.logspace(-6,-1,5),cv=5,scoring="accuracy"):
    """
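    Description: plot how the different values of one parameter score on the
    training folds and on the held-out cross-validation folds.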
    """
    from sklearn.model_selection import validation_curve
    import matplotlib.pyplot as plt
    
    train_scores,test_scores = validation_curve(estimator=estimator, 
                                                X=X, 
                                                y=y, 
                                                cv=cv,
                                                scoring=scoring,
                                                param_name=param_name,
                                                param_range=param_range)
    
    train_scores_mean = np.mean(train_scores, axis=1)
    train_scores_std  = np.std(train_scores, axis=1)
    test_scores_mean  = np.mean(test_scores, axis=1)
    test_scores_std   = np.std(test_scores, axis=1)
    
    plt.title("Validation Curve")
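    # Raw string avoids the invalid "\g" escape; fall back to the parameter name for other parameters
    plt.xlabel(r"$\gamma$" if param_name == "gamma" else param_name)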
    plt.ylabel("Score")
    plt.ylim(0.0, 1.1)
    
    plt.semilogx(param_range,train_scores_mean,label="Training score",color="darkorange", lw=2)
    plt.fill_between(param_range,
                     train_scores_mean-train_scores_std,
                     train_scores_mean+train_scores_std,
                     alpha=0.2,
                     color="darkorange", 
                     lw=2)
    
    plt.semilogx(param_range, test_scores_mean, label="Cross-validation score",color="navy", lw=2)    
    plt.fill_between(param_range, 
                     test_scores_mean - test_scores_std,
                     test_scores_mean + test_scores_std, 
                     alpha=0.2,
                     color="navy", 
                     lw=2)
    
    plt.legend(loc="best")
    plt.show()
    

    
plot_validation_curve(estimator=SVC(),
                      X=X,y=y,
                      param_name="gamma",
                      param_range=np.logspace(-6,-1,5),cv=5,scoring="accuracy")    
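
The same scores can also be read numerically rather than from the plot. The following is a minimal sketch, not part of the original post: it assumes the same `SVC` estimator and `gamma` range as above, fixes `random_state=0` so the run is repeatable, and uses illustrative names (`X_demo`, `y_demo`, `gamma_range`, `best_gamma`) that do not appear in the original code. It picks the `gamma` value with the highest mean cross-validation accuracy.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import validation_curve
from sklearn.svm import SVC

# Assumption: a fixed seed, so the synthetic data is the same on every run
X_demo, y_demo = make_classification(random_state=0)
gamma_range = np.logspace(-6, -1, 5)

# Both score arrays have shape (n_param_values, n_cv_folds)
train_scores, test_scores = validation_curve(
    SVC(), X_demo, y_demo,
    param_name="gamma", param_range=gamma_range,
    cv=5, scoring="accuracy",
)

# Average over the folds, then take the value with the best held-out score
test_mean = test_scores.mean(axis=1)
best_idx = int(np.argmax(test_mean))
best_gamma = gamma_range[best_idx]

print("mean CV accuracy per gamma:", np.round(test_mean, 3))
print("best gamma:", best_gamma)
```

A large gap between the training score and the cross-validation score at a given `gamma` suggests overfitting there, while low scores on both suggest underfitting.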
    

 
