K折验证

Posted wbloger

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了K折验证相关的知识,希望对你有一定的参考价值。

"""K折验证"""

#K validation

import numpy as np

k = 4
num_val_samples = len(train_data) // k
num_epochs = 100
all_scores = []

for i in range(k):
    print("processing fold #", i)
    val_data = train_data[i * num_val_samples:(i + 1) * num_val_samples]
    val_targets = train_targets[i * num_val_samples : (i + 1) * num_val_samples]
    
    partial_train_data = np.concatenate([
            train_data[: i * num_val_samples], train_data[(i+1) * num_val_samples :]], axis = 0)
    
    partial_train_targets = np.concatenate([
            train_targets[: i * num_val_samples], train_targets[ (i + 1) * num_val_samples :]], axis = 0)
        
    model = build_model()
    model.fit(partial_train_data, partial_train_targets, 
              epochs = num_epochs, batch_size = 1, verbose = 0)
    val_mse, val_mae = model.evaluate(val_data, val_targets, verbose = 0)
    all_scores.append(val_mae)

all_scores

np.mean(all_scores)

"""保存每折验证的结果"""

#save k-viladation results
num_epochs = 500
all_mae_histories=  []
for i in range(k):
    print("processing fold #", i)
    val_data = train_data[i * num_val_samples : (i + 1) * num_val_samples]
    val_targets = train_targets[i * num_val_samples : (i + 1) * num_val_samples]
    
    partial_train_data = np.concatenate(
            [train_data[: i * num_val_samples], train_data[(i + 1) * num_val_samples :]], axis = 0)
    partial_train_targets = np.concatenate([train_targets[: i * num_val_samples], train_targets[(i + 1) * num_val_samples :]], axis = 0)
    
    model = build_model()
    history = model.fit(partial_train_data, partial_train_targets,
                        validation_data = (val_data, val_targets), 
                        epochs = num_epochs, batch_size = 1, verbose = 0)
    mae_history = history.history[val_mean_absolute_error]
    all_mae_histories.append(mae_history)

 

以上是关于K折验证的主要内容,如果未能解决你的问题,请参考以下文章

在 MLPClassification Python 中实现 K 折交叉验证

小白学习之pytorch框架之实战Kaggle比赛:房价预测(K折交叉验证*args**kwargs)

使用 K 折交叉验证标准化数据

在 MATLAB 中测试模型准确性的 K 折交叉验证

在使用 k 折交叉验证训练训练数据后如何测试数据?

如何计算分层 K 折交叉验证的不平衡数据集的误报率?