Data Analysis: Cross-Validation

Posted by 慢慢来会比较快

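Cross-validation can be done with cross_val_score. learning_curve and validation_curve also cross-validate internally: the first across increasing training-set sizes, the second across values of a single hyperparameter.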

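from sklearn.datasets import load_iris
# cross_val_score lives in sklearn.model_selection; the old sklearn.cross_validation module has been removed
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier
import matplotlib.pyplot as plt
# Jupyter/IPython magic; drop this line when running as a plain script
%matplotlib inline

iris = load_iris()
x_data = iris.data
y_data = iris.target

# Mean 10-fold cross-validated accuracy for each k from 1 to 30
k_score = []
for k in range(1,31):
    knn = KNeighborsClassifier(n_neighbors=k)
    score = cross_val_score(knn,x_data,y_data,cv=10,scoring='accuracy')
    k_score.append(score.mean())

plt.figure()
plt.plot(range(1,31),k_score)
plt.xlabel('k (n_neighbors)')
plt.ylabel('mean CV accuracy')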
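
To make clearer what the cross_val_score call above is doing, here is a minimal sketch that reproduces it with an explicit loop over folds. It assumes a recent scikit-learn, where an integer cv with a classifier means an unshuffled StratifiedKFold; the choice of n_neighbors=5 and the names X, y, manual_scores and library_scores are only for illustration and are not from the original post.

```python
import numpy as np
from sklearn.base import clone
from sklearn.datasets import load_iris
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.neighbors import KNeighborsClassifier

X, y = load_iris(return_X_y=True)
knn = KNeighborsClassifier(n_neighbors=5)  # k=5 is an arbitrary illustrative choice

# Split the data into 10 stratified folds; each fold is the test set exactly once
skf = StratifiedKFold(n_splits=10)
manual_scores = []
for train_idx, test_idx in skf.split(X, y):
    model = clone(knn)                      # fresh, unfitted copy for every fold
    model.fit(X[train_idx], y[train_idx])   # train on the other 9 folds
    manual_scores.append(model.score(X[test_idx], y[test_idx]))  # accuracy on the held-out fold
manual_scores = np.array(manual_scores)

# The library call should yield the same per-fold accuracies
library_scores = cross_val_score(knn, X, y, cv=10, scoring='accuracy')
print(manual_scores.mean(), library_scores.mean())
```

Averaging the per-fold scores, as the loop over k does, gives a single estimate per setting that depends less on any one train/test split.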

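# learning_curve lives in sklearn.model_selection; the old sklearn.learning_curve module has been removed
from sklearn.model_selection import learning_curve
from sklearn.datasets import load_digits
from sklearn.svm import SVC

digits = load_digits()
x_data = digits.data
y_data = digits.target

# 10-fold CV accuracy when training on 10%, 25%, 50%, 75% and 100% of the available training data
train_size,train_score,test_score = learning_curve(SVC(gamma=0.001),x_data,y_data,cv=10,scoring='accuracy',train_sizes=[0.1,0.25,0.5,0.75,1.0])

# Average over the folds. Accuracy is a score (higher is better), so it is plotted as-is rather than negated
train_score_mean = train_score.mean(axis=1)
test_score_mean = test_score.mean(axis=1)
plt.figure()
plt.plot(train_size,train_score_mean,'r-o',label='train_accuracy')
plt.plot(train_size,test_score_mean,'g-o',label='cv_accuracy')
plt.legend()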

import numpy as np
# validation_curve also lives in sklearn.model_selection
from sklearn.model_selection import validation_curve
from sklearn.datasets import load_digits
from sklearn.svm import SVC
digits = load_digits()
x_data = digits.data
y_data = digits.target

# 10-fold CV accuracy for gamma = 1e-6, 1e-5, ..., 1e-2
param_range = np.logspace(-6,-2,5)
train_score,test_score = validation_curve(SVC(),x_data,y_data,param_name='gamma',param_range=param_range,cv=10,scoring='accuracy')
train_score_mean = train_score.mean(axis=1)
test_score_mean = test_score.mean(axis=1)

plt.figure()
plt.plot(param_range,train_score_mean,'r-o',label='train_accuracy')
plt.plot(param_range,test_score_mean,'g-o',label='cv_accuracy')
plt.xscale('log')  # gamma spans several orders of magnitude
plt.legend()
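
The validation curve can also be read off programmatically instead of by eye. The following is a rough sketch, not part of the original post: it picks the gamma with the highest mean cross-validated accuracy. The names X, y, test_mean and best_gamma are illustrative, and cv=5 is used here only to keep the run short.

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.model_selection import validation_curve
from sklearn.svm import SVC

X, y = load_digits(return_X_y=True)
param_range = np.logspace(-6, -2, 5)

# Rows correspond to gamma values, columns to CV folds
train_scores, test_scores = validation_curve(
    SVC(), X, y,
    param_name='gamma', param_range=param_range,
    cv=5, scoring='accuracy',
)

# Mean held-out accuracy per gamma, then the gamma that maximizes it
test_mean = test_scores.mean(axis=1)
best_gamma = param_range[np.argmax(test_mean)]
print('best gamma:', best_gamma, 'mean CV accuracy:', test_mean.max())
```

A large gap between training and held-out accuracy at a given gamma suggests overfitting at that setting, which is why the choice is made on the held-out scores rather than the training scores.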

 
