交叉验证
Posted ywheunji
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了交叉验证相关的知识,希望对你有一定的参考价值。
PS: sklearn.cross_validation已经被移除所以改为从 sklearn.model_selection 中调用train_test_split 等函数
K层交叉检验就是把原始的数据随机分成K个部分。在这K个部分中,选择一个作为测试数据,剩下的K-1个
作为训练数据。 交叉检验的过程实际上是把实验重复做K次,每次实验都从K个部分中选择一个不同的部分
作为测试数据(保证K个部分的数据都分别做过测试数据),剩下的K-1个当作训练数据进行实验,最后把
得到的K个实验结果平均。
验证模型的精确度,或者某个参数的影响
scores = cross_val_score(knn, X, y, cv=10, scoring=‘accuracy‘)
其中accuracy一般用于分类问题的评估,而均方差一般用于回归问题的评估。
参考代码如下:
import matplotlib.pyplot as plt #可视化模块
#建立测试参数集
k_range = range(1, 31)
k_scores = []
#藉由迭代的方式来计算不同参数对模型的影响,并返回交叉验证后的平均准确率
for k in k_range:
knn = KNeighborsClassifier(n_neighbors=k)
scores = cross_val_score(knn, X, y, cv=10, scoring=‘accuracy‘)
#X,y是原始数据集未分割,knn为初始模型未训练
k_scores.append(scores.mean())
#可视化数据
plt.plot(k_range, k_scores)
plt.xlabel(‘Value of K for KNN‘)
plt.ylabel(‘Cross-Validated Accuracy‘)
plt.show()
交叉验证时候查看训练过程中的学习效果,利用Learning curve 检视过拟合
1 train_sizes, train_loss, test_loss = learning_curve( 2 SVC(gamma=0.001), X, y, cv=10, scoring=‘mean_squared_error‘, 3 train_sizes=[0.1, 0.25, 0.5, 0.75, 1]) 4 5 #平均每一轮所得到的平均方差(共5轮,分别为样本10%、25%、50%、75%、100%)
train_size的参数是指在训练数据的10%,25%时候的误差。
主要内容:在调节模型的某个参数时候可以用validation_curve来测试Loss值的变化
from sklearn.learning_curve import validation_curve #validation_curve模块
from sklearn.datasets import load_digits
from sklearn.svm import SVC
import matplotlib.pyplot as plt
import numpy as np
#digits数据集
digits = load_digits()
X = digits.data
y = digits.target
#建立参数测试集
param_range = np.logspace(-6, -2.3, 5)
#使用validation_curve快速找出参数对模型的影响
train_loss, test_loss = validation_curve(
SVC(), X, y, param_name=‘gamma‘, param_range=param_range, cv=10, scoring=‘mean_squared_error‘)
#平均每一轮的平均方差
train_loss_mean = -np.mean(train_loss, axis=1)
test_loss_mean = -np.mean(test_loss, axis=1)
#可视化图形
plt.plot(param_range, train_loss_mean, ‘o-‘, color="r",
label="Training")
plt.plot(param_range, test_loss_mean, ‘o-‘, color="g",
label="Cross-validation")
plt.xlabel("gamma")
plt.ylabel("Loss")
plt.legend(loc="best")
plt.show()
param_range为参数的取值范围以log形式表现, 模型中SVC()不传参数,在param_name中传入要测试的参数名称,在param-range赋予取值范围。
以上是关于交叉验证的主要内容,如果未能解决你的问题,请参考以下文章