数据分析——交叉验证
Posted 慢慢来会比较快
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了数据分析——交叉验证相关的知识,希望对你有一定的参考价值。
使用 cross_val_score 可以做交叉验证,learning_curve 和 validation_curve 也可以。
# Select the KNN hyper-parameter k on the iris dataset:
# for each k = 1..30, run 10-fold cross-validation and plot the mean accuracy.
from sklearn.datasets import load_iris
# NOTE: sklearn.cross_validation was removed in scikit-learn 0.20;
# cross_val_score now lives in sklearn.model_selection.
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier
import matplotlib.pyplot as plt
# %matplotlib inline  (IPython-only magic; not valid in a plain .py script)

iris = load_iris()
x_data = iris.data
y_data = iris.target

k_score = []
for k in range(1, 31):
    knn = KNeighborsClassifier(n_neighbors=k)
    # cv=10: 10-fold CV; score is an array of per-fold accuracies.
    score = cross_val_score(knn, x_data, y_data, cv=10, scoring='accuracy')
    k_score.append(score.mean())

plt.figure()
plt.plot(range(1, 31), k_score)
# Learning curve for an SVC on the digits dataset: mean 10-fold CV accuracy
# as a function of the training-set size.
# NOTE: sklearn.learning_curve was removed in scikit-learn 0.20;
# learning_curve now lives in sklearn.model_selection.
from sklearn.model_selection import learning_curve
from sklearn.datasets import load_digits
from sklearn.svm import SVC
import matplotlib.pyplot as plt

digits = load_digits()
x_data = digits.data
y_data = digits.target

train_size, train_scores, test_scores = learning_curve(
    SVC(gamma=0.001), x_data, y_data,
    cv=10, scoring='accuracy',
    train_sizes=[0.1, 0.25, 0.5, 0.75, 1],
)
train_score_mean = train_scores.mean(axis=1)
test_score_mean = test_scores.mean(axis=1)

# Plot the accuracy directly. The original negated the scores and labelled
# them "loss", but with scoring='accuracy' that plots negative accuracy,
# which is not a loss of any kind.
plt.plot(train_size, train_score_mean, 'r-o', label='train_score')
plt.plot(train_size, test_score_mean, 'g-o', label='test_score')
plt.legend()
# Validation curve: mean 10-fold CV accuracy of an SVC on the digits dataset
# as a function of the gamma hyper-parameter (log-spaced over 1e-6..1e-2).
import numpy as np  # the original used np.logspace without importing numpy
import matplotlib.pyplot as plt
# NOTE: sklearn.learning_curve was removed in scikit-learn 0.20;
# validation_curve now lives in sklearn.model_selection.
from sklearn.model_selection import validation_curve
from sklearn.datasets import load_digits
from sklearn.svm import SVC

digits = load_digits()
x_data = digits.data
y_data = digits.target

# Hoisted: the original recomputed np.logspace(-6, -2, 5) three times.
param_range = np.logspace(-6, -2, 5)

train_scores, test_scores = validation_curve(
    SVC(), x_data, y_data,
    param_name='gamma', param_range=param_range,
    cv=10, scoring='accuracy',
)
train_score_mean = train_scores.mean(axis=1)
test_score_mean = test_scores.mean(axis=1)

# Plot the accuracy directly. The original negated the scores and labelled
# them "loss", but with scoring='accuracy' that plots negative accuracy,
# which is not a loss of any kind.
plt.figure()
plt.plot(param_range, train_score_mean, 'r-o', label='train_score')
plt.plot(param_range, test_score_mean, 'g-o', label='test_score')
plt.legend()
以上是关于数据分析——交叉验证的主要内容,如果未能解决你的问题,请参考以下文章