6、可视化交叉验证

Posted

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了6、可视化交叉验证相关的知识,希望对你有一定的参考价值。

参考技术A 6、可视化交叉验证

from sklearn.model_selection import (TimeSeriesSplit, KFold, ShuffleSplit,

                                    StratifiedKFold, GroupShuffleSplit,

                                    GroupKFold, StratifiedShuffleSplit)

import numpy as np

import matplotlib.pyplot as plt

from matplotlib.patches import Patch

np.random.seed(1338)

cmap_data = plt.cm.Paired

cmap_cv = plt.cm.coolwarm

n_splits = 4

plt.rcParams['font.sans-serif'] = ['SimHei']

plt.rcParams['axes.unicode_minus'] = False

np.random.seed(1338)

cmap_data = plt.cm.Paired

cmap_cv = plt.cm.coolwarm

n_splits = 4

# 生成类别/组数据

n_points = 100

X = np.random.randn(100, 10)

percentiles_classes = [.1, .3, .6]

y = np.hstack([[ii] * int(100 * perc)

              for ii, perc in enumerate(percentiles_classes)])

# 间隔均匀的组重复一次

groups = np.hstack([[ii] * 10 for ii in range(10)])

def visualize_groups(classes, groups, name):

    # 可视化数据集组

    fig, ax = plt.subplots()

    ax.scatter(range(len(groups)),  [.5] * len(groups), c=groups, marker='_',

              lw=50, cmap=cmap_data)

    ax.scatter(range(len(groups)),  [3.5] * len(groups), c=classes, marker='_',

              lw=50, cmap=cmap_data)

    ax.set(ylim=[-1, 5], yticks=[.5, 3.5],

          yticklabels=['Data\ngroup', 'Data\nclass'], xlabel="Sample index")

visualize_groups(y, groups, 'no groups')

plt.title("可视化数据", fontsize=15)

def plot_cv_indices(cv, X, y, group, ax, n_splits, lw=10):

    """为交叉验证对象的索引创建样本图."""

    # 为每个交叉验证分组生成训练/测试可视化图像

    for ii, (tr, tt) in enumerate(cv.split(X=X, y=y, groups=group)):

        # 与训练/测试组一起填写索引

        indices = np.array([np.nan] * len(X))

        indices[tt] = 1

        indices[tr] = 0

        # 可视化结果

        ax.scatter(range(len(indices)), [ii + .5] * len(indices),

                  c=indices, marker='_', lw=lw, cmap=cmap_cv,

                  vmin=-.2, vmax=1.2)

    # 将数据的分组情况和标签情况放入图像

    ax.scatter(range(len(X)), [ii + 1.5] * len(X),

              c=y, marker='_', lw=lw, cmap=cmap_data)

    ax.scatter(range(len(X)), [ii + 2.5] * len(X),

              c=group, marker='_', lw=lw, cmap=cmap_data)

    # 调整格式

    yticklabels = list(range(n_splits)) + ['class', 'group']

    ax.set(yticks=np.arange(n_splits+2) + .5, yticklabels=yticklabels,

          xlabel='Sample index', ylabel="CV iteration",

          ylim=[n_splits+2.2, -.2], xlim=[0, 100])

    ax.set_title(''.format(type(cv).__name__), fontsize=15)

    return ax

fig, ax = plt.subplots()

cv = KFold(n_splits)

plot_cv_indices(cv, X, y, groups, ax, n_splits)

fig, ax = plt.subplots()

cv = StratifiedKFold(n_splits)

plot_cv_indices(cv, X, y, groups, ax, n_splits)

cvs = [KFold, GroupKFold, ShuffleSplit, StratifiedKFold,

      GroupShuffleSplit, StratifiedShuffleSplit,

      TimeSeriesSplit]

for cv in cvs:

    this_cv = cv(n_splits=n_splits)

    fig, ax = plt.subplots(figsize=(6, 3))

    plot_cv_indices(this_cv, X, y, groups, ax, n_splits)

    ax.legend([Patch(color=cmap_cv(.8)), Patch(color=cmap_cv(.02))],

              ['Testing set', 'Training set'], loc=(1.02, .8))

    # Make the legend fit

    plt.tight_layout()

    fig.subplots_adjust(right=.7)

plt.title("可视化交叉验证\n StratifiedKGroupFold", fontsize=15)   

plt.show()

寻找模型最优参数多模型交叉验证可视化指标计算多模型对比可视化(系数图误差图混淆矩阵校正曲线ROC曲线AUCAccuracy特异度灵敏度PPVNPV)

使用randomsearchcv寻找模型最优参数、多模型交叉验证、可视化、指标计算、多模型对比可视化(系数图、误差图、classification_report、混淆矩阵、校正曲线、ROC曲线、AUC、Accuracy、特异度、灵敏度、PPV、NPV)

目录

以上是关于6、可视化交叉验证的主要内容,如果未能解决你的问题,请参考以下文章

python基于tpot训练模型在获得最佳模型之后对模型进行交叉验证分析并可视化实战

树的随机森林数和交叉验证

R语言使用yardstick包评估模型性能(二分类多分类回归模型交叉验证每一折的指标npvppvaccuracyauckapparecallrmsemaer2等以及可视化)

R语言使用yardstick包评估模型性能(二分类多分类回归模型交叉验证每一折的指标npvppvaccuracyauckapparecallrmsemaer2等以及可视化)

机器学习交叉验证和网格搜索案例分析

R语言使用yardstick包的lift_curve函数评估多分类(Multiclass)模型的性能并使用autoplot函数可视化模型在每个交叉验证(或者重采样)的每一折fold在每个分类上的提升