Why does a 50-50 train/test split work best for a dataset of 178 observations with this neural network?

Posted 2018-03-30 21:23:56

Question:

From what I have read, a split of roughly 80% training data and 20% validation data is close to optimal. As the test set gets larger, the variance of the validation results should decrease, at the cost of less effective training (lower validation accuracy).
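
As a rough sanity check on that intuition, here is a minimal sketch of how the mean and spread of validation accuracy could be measured across several split sizes. It uses scikit-learn's bundled Wine data, which happens to have the same shape as my dataset (178 samples, 13 features, 3 classes), and a small MLPClassifier purely as a stand-in for my own network, so the exact numbers will differ from mine:

import numpy as np
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier
from sklearn.preprocessing import StandardScaler

X, y = load_wine(return_X_y=True)  # 178 samples, 13 features, 3 classes

for test_size in (0.1, 0.5, 0.9):
    scores = []
    for trial in range(20):  # repeat the split/train/score cycle to estimate the variance
        X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=test_size)
        scaler = StandardScaler().fit(X_tr)  # scaling fitted on the training split only
        clf = MLPClassifier(hidden_layer_sizes=(5,), max_iter=2000)
        clf.fit(scaler.transform(X_tr), y_tr)
        scores.append(clf.score(scaler.transform(X_te), y_te))
    print('test_size=%.1f  mean acc=%.3f  std=%.3f' % (test_size, np.mean(scores), np.std(scores)))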

Given that, I am confused by the results below, which instead show the best accuracy and the lowest variance at TEST_SIZE=0.5 (each configuration was run several times; one trial was picked to represent each test size).

TEST_SIZE=0.1: this should work well thanks to the large training set, but the variance is high (accuracy varied between 16% and 50% over 5 trials).

Epoch     0, Loss 0.021541, Targets [ 1.  0.  0.], Outputs [ 0.979  0.011  0.01 ], Inputs [ 0.086  0.052  0.08   0.062  0.101  0.093  0.107  0.058  0.108  0.08   0.084  0.115  0.104]
Epoch   100, Loss 0.001154, Targets [ 0.  0.  1.], Outputs [ 0.     0.001  0.999], Inputs [ 0.083  0.099  0.084  0.079  0.085  0.061  0.02   0.103  0.038  0.083  0.078  0.053  0.067]
Epoch   200, Loss 0.000015, Targets [ 0.  0.  1.], Outputs [ 0.  0.  1.], Inputs [ 0.076  0.092  0.087  0.107  0.077  0.063  0.02   0.13   0.054  0.106  0.054  0.051  0.086]
Target Class 0, Predicted Class 0
Target Class 0, Predicted Class 0
Target Class 1, Predicted Class 0
Target Class 1, Predicted Class 0
Target Class 1, Predicted Class 0
Target Class 0, Predicted Class 0
Target Class 1, Predicted Class 0
Target Class 1, Predicted Class 0
Target Class 1, Predicted Class 0
Target Class 1, Predicted Class 0
Target Class 0, Predicted Class 0
Target Class 0, Predicted Class 0
Target Class 1, Predicted Class 0
Target Class 0, Predicted Class 0
Target Class 2, Predicted Class 2
Target Class 1, Predicted Class 0
Target Class 0, Predicted Class 0
Target Class 2, Predicted Class 2
50.0% overall accuracy for validation set.

TEST_SIZE=0.5: this should land somewhere between the other two cases, yet for some reason accuracy only varied between 92% and 97% over 5 trials.

Epoch     0, Loss 0.547218, Targets [ 1.  0.  0.], Outputs [ 0.579  0.087  0.334], Inputs [ 0.106  0.08   0.142  0.133  0.129  0.115  0.127  0.13   0.12   0.068  0.123  0.126  0.11 ]
Epoch   100, Loss 0.002716, Targets [ 0.  1.  0.], Outputs [ 0.003  0.997  0.   ], Inputs [ 0.09   0.059  0.097  0.114  0.088  0.108  0.102  0.144  0.125  0.036  0.186  0.113  0.054]
Epoch   200, Loss 0.002874, Targets [ 0.  1.  0.], Outputs [ 0.003  0.997  0.   ], Inputs [ 0.102  0.067  0.088  0.109  0.088  0.097  0.091  0.088  0.092  0.056  0.113  0.141  0.089]
Target Class 1, Predicted Class 1
Target Class 0, Predicted Class 0
Target Class 2, Predicted Class 2
Target Class 1, Predicted Class 1
Target Class 2, Predicted Class 2
Target Class 2, Predicted Class 2
Target Class 2, Predicted Class 2
Target Class 1, Predicted Class 1
Target Class 0, Predicted Class 0
Target Class 0, Predicted Class 0
Target Class 0, Predicted Class 0
Target Class 0, Predicted Class 0
Target Class 0, Predicted Class 0
Target Class 0, Predicted Class 0
Target Class 1, Predicted Class 0
Target Class 2, Predicted Class 2
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 0, Predicted Class 0
Target Class 2, Predicted Class 2
Target Class 1, Predicted Class 1
Target Class 0, Predicted Class 0
Target Class 1, Predicted Class 1
Target Class 2, Predicted Class 2
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 2, Predicted Class 1
Target Class 2, Predicted Class 2
Target Class 0, Predicted Class 0
Target Class 1, Predicted Class 1
Target Class 0, Predicted Class 0
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 0, Predicted Class 0
Target Class 2, Predicted Class 2
Target Class 0, Predicted Class 0
Target Class 0, Predicted Class 0
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 0, Predicted Class 0
Target Class 1, Predicted Class 1
Target Class 0, Predicted Class 0
Target Class 0, Predicted Class 0
Target Class 1, Predicted Class 1
Target Class 2, Predicted Class 2
Target Class 2, Predicted Class 2
Target Class 1, Predicted Class 1
Target Class 0, Predicted Class 0
Target Class 1, Predicted Class 1
Target Class 2, Predicted Class 2
Target Class 1, Predicted Class 1
Target Class 2, Predicted Class 2
Target Class 1, Predicted Class 1
Target Class 0, Predicted Class 0
Target Class 1, Predicted Class 1
Target Class 0, Predicted Class 0
Target Class 2, Predicted Class 2
Target Class 2, Predicted Class 2
Target Class 1, Predicted Class 1
Target Class 0, Predicted Class 0
Target Class 2, Predicted Class 2
Target Class 2, Predicted Class 2
Target Class 0, Predicted Class 0
Target Class 0, Predicted Class 0
Target Class 1, Predicted Class 1
Target Class 0, Predicted Class 0
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 0, Predicted Class 0
Target Class 1, Predicted Class 1
Target Class 0, Predicted Class 0
Target Class 2, Predicted Class 2
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 0, Predicted Class 0
Target Class 1, Predicted Class 1
Target Class 0, Predicted Class 0
Target Class 2, Predicted Class 2
Target Class 2, Predicted Class 2
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 2, Predicted Class 2
Target Class 2, Predicted Class 2
Target Class 1, Predicted Class 1
Target Class 0, Predicted Class 0
Target Class 1, Predicted Class 1
97.75280898876404% overall accuracy for validation set.

TEST_SIZE=0.9: this should generalize poorly because of the tiny training set; accuracy varied between 38% and 54% over 5 trials.

Epoch     0, Loss 2.448474, Targets [ 0.  0.  1.], Outputs [ 0.707  0.206  0.086], Inputs [ 0.229  0.421  0.266  0.267  0.223  0.15   0.057  0.33   0.134  0.148  0.191  0.12   0.24 ]
Epoch   100, Loss 0.017506, Targets [ 1.  0.  0.], Outputs [ 0.983  0.017  0.   ], Inputs [ 0.252  0.162  0.274  0.255  0.241  0.275  0.314  0.175  0.278  0.135  0.286  0.36   0.281]
Epoch   200, Loss 0.001819, Targets [ 0.  0.  1.], Outputs [ 0.002  0.     0.998], Inputs [ 0.245  0.348  0.248  0.274  0.284  0.153  0.167  0.212  0.191  0.362  0.145  0.125  0.183]
Target Class 2, Predicted Class 2
Target Class 2, Predicted Class 2
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 0, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 2
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 2, Predicted Class 2
Target Class 0, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 2, Predicted Class 2
Target Class 0, Predicted Class 1
Target Class 2, Predicted Class 2
Target Class 1, Predicted Class 1
Target Class 2, Predicted Class 2
Target Class 2, Predicted Class 2
Target Class 0, Predicted Class 1
Target Class 2, Predicted Class 2
Target Class 2, Predicted Class 2
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 2, Predicted Class 2
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 2, Predicted Class 2
Target Class 0, Predicted Class 1
Target Class 2, Predicted Class 2
Target Class 2, Predicted Class 2
Target Class 1, Predicted Class 1
Target Class 0, Predicted Class 1
Target Class 0, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 0, Predicted Class 1
Target Class 0, Predicted Class 1
Target Class 0, Predicted Class 1
Target Class 0, Predicted Class 1
Target Class 2, Predicted Class 2
Target Class 0, Predicted Class 1
Target Class 2, Predicted Class 2
Target Class 0, Predicted Class 1
Target Class 0, Predicted Class 1
Target Class 0, Predicted Class 1
Target Class 1, Predicted Class 2
Target Class 2, Predicted Class 2
Target Class 2, Predicted Class 2
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 2, Predicted Class 2
Target Class 0, Predicted Class 1
Target Class 2, Predicted Class 2
Target Class 2, Predicted Class 2
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 0, Predicted Class 1
Target Class 0, Predicted Class 1
Target Class 0, Predicted Class 1
Target Class 2, Predicted Class 2
Target Class 0, Predicted Class 1
Target Class 0, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 2, Predicted Class 2
Target Class 0, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 0, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 0, Predicted Class 1
Target Class 0, Predicted Class 1
Target Class 2, Predicted Class 2
Target Class 1, Predicted Class 1
Target Class 2, Predicted Class 2
Target Class 0, Predicted Class 1
Target Class 0, Predicted Class 1
Target Class 2, Predicted Class 2
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 0, Predicted Class 1
Target Class 2, Predicted Class 2
Target Class 1, Predicted Class 1
Target Class 0, Predicted Class 1
Target Class 0, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 2, Predicted Class 2
Target Class 0, Predicted Class 1
Target Class 0, Predicted Class 1
Target Class 0, Predicted Class 1
Target Class 2, Predicted Class 2
Target Class 2, Predicted Class 2
Target Class 0, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 0, Predicted Class 1
Target Class 0, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 2, Predicted Class 2
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 2, Predicted Class 2
Target Class 0, Predicted Class 1
Target Class 0, Predicted Class 1
Target Class 2, Predicted Class 2
Target Class 0, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 0, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 2, Predicted Class 2
Target Class 1, Predicted Class 1
Target Class 0, Predicted Class 1
Target Class 2, Predicted Class 2
Target Class 0, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 0, Predicted Class 1
Target Class 2, Predicted Class 2
Target Class 0, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 0, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 2, Predicted Class 2
Target Class 0, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 2
Target Class 1, Predicted Class 1
Target Class 0, Predicted Class 1
Target Class 0, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 0, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 2, Predicted Class 2
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 0, Predicted Class 1
Target Class 0, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 2, Predicted Class 2
Target Class 2, Predicted Class 2
Target Class 2, Predicted Class 2
Target Class 0, Predicted Class 1
Target Class 0, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 0, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 2, Predicted Class 2
64.59627329192547% overall accuracy for validation set.

The key code snippets are as follows:

Importing and splitting the dataset

import numpy as np
from sklearn.preprocessing import normalize
from sklearn.model_selection import train_test_split


def readInput(filename, delimiter, inputlen, outputlen, categories, test_size):
    # convert a 1-based class label into a one-hot vector of length `categories`
    def onehot(num, categories):
        arr = np.zeros(categories)
        arr[int(num[0])-1] = 1
        return arr

    # each row of the file holds the class label first, then the feature values
    with open(filename) as file:
        inputs = list()
        outputs = list()
        for line in file:
            assert(len(line.split(delimiter)) == inputlen+outputlen)
            values = list(map(float, line.split(delimiter)))
            outputs.append(onehot(values[:outputlen], categories))
            inputs.append(values[outputlen:outputlen+inputlen])
    inputs = np.array(inputs)
    outputs = np.array(outputs)

    # random split into training and validation sets
    inputs_train, inputs_val, outputs_train, outputs_val = train_test_split(inputs, outputs, test_size=test_size)
    assert len(inputs_train) > 0
    assert len(inputs_val) > 0

    # normalize(..., axis=0) scales each feature column to unit L2 norm,
    # computed separately for the training split and the validation split
    return normalize(inputs_train, axis=0), outputs_train, normalize(inputs_val, axis=0), outputs_val
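
One detail worth spelling out about the last line: sklearn.preprocessing.normalize(..., axis=0) rescales each feature column to unit L2 norm, and it does so separately for the training and the validation split, so the resulting scale depends on how many rows each split contains. A tiny sketch, independent of my actual data:

import numpy as np
from sklearn.preprocessing import normalize

col = np.full((160, 1), 5.0)      # a constant feature over 160 training rows
print(normalize(col, axis=0)[0])  # every entry becomes 1/sqrt(160), about 0.079

col = np.full((18, 1), 5.0)       # the same feature over 18 validation rows
print(normalize(col, axis=0)[0])  # every entry becomes 1/sqrt(18), about 0.236

This is consistent with the different input magnitudes visible in the training logs above.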

Some parameters

import numpy as np
import helper

FILE_NAME = 'data2.csv'
DATA_DELIM = ','
ACTIVATION_FUNC = 'tanh'
TESTING_FREQ = 100
EPOCHS = 200
LEARNING_RATE = 0.2
TEST_SIZE = 0.9  # varied between 0.1, 0.5 and 0.9 in the runs above

INPUT_SIZE = 13
HIDDEN_LAYERS = [5]
OUTPUT_SIZE = 3
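
For reference, a quick count of the trainable parameters these settings imply, assuming fully connected layers with one bias per non-input unit (if the Classifier uses no biases, the count is slightly smaller):

layers = [INPUT_SIZE] + HIDDEN_LAYERS + [OUTPUT_SIZE]   # [13, 5, 3]
# weights between consecutive layers plus one bias per unit in the next layer
n_params = sum(n_in * n_out + n_out for n_in, n_out in zip(layers, layers[1:]))
print(n_params)  # 13*5+5 + 5*3+3 = 88 parameters, fitted from at most 178 samples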

Main program flow

    # methods of the Classifier class (forward_propagate, backpropagate_errors
    # and adjust_weights are defined elsewhere in the class and not shown here)

    def step(self, x, targets, lrate):
        # one training update on a single sample
        self.forward_propagate(x)
        self.backpropagate_errors(targets)
        self.adjust_weights(x, lrate)

    def test(self, epoch, x, target):
        # print the loss on one sample for progress monitoring
        predictions = self.forward_propagate(x)
        print('Epoch %5i, Loss %2f, Targets %s, Outputs %s, Inputs %s' % (epoch, helper.crossentropy(target, predictions), target, predictions, x))

    def train(self, inputs, targets, epochs, testfreq, lrate):
        # per-sample (online) training; the sample order is reshuffled every epoch
        xindices = [i for i in range(len(inputs))]
        for epoch in range(epochs):
            np.random.shuffle(xindices)
            if epoch % testfreq == 0:
                self.test(epoch, inputs[xindices[0]], targets[xindices[0]])
            for i in xindices:
                self.step(inputs[i], targets[i], lrate)
        self.test(epochs, inputs[xindices[0]], targets[xindices[0]])

    def validate(self, inputs, targets):
        # take the argmax of the one-hot targets and of the network outputs, count matches
        correct = 0
        targets = np.argmax(targets, axis=1)
        for i in range(len(inputs)):
            prediction = np.argmax(self.forward_propagate(inputs[i]))
            if prediction == targets[i]: correct += 1
            print('Target Class %s, Predicted Class %s' % (targets[i], prediction))
        print('%s%% overall accuracy for validation set.' % (correct/len(inputs)*100))


np.random.seed()  # no fixed seed, so every run uses a different random split

inputs_train, outputs_train, inputs_val, outputs_val = helper.readInput(FILE_NAME, DATA_DELIM, inputlen=INPUT_SIZE, outputlen=1, categories=OUTPUT_SIZE, test_size=TEST_SIZE)
nn = Classifier([INPUT_SIZE] + HIDDEN_LAYERS + [OUTPUT_SIZE], ACTIVATION_FUNC)

nn.train(inputs_train, outputs_train, EPOCHS, TESTING_FREQ, LEARNING_RATE)

nn.validate(inputs_val, outputs_val)
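
To make the long per-sample printout easier to scan, a per-class summary could be printed instead. A small sketch that reuses the trained nn and the validation arrays from above (and assumes scikit-learn is available for the confusion matrix):

from sklearn.metrics import confusion_matrix

y_true = np.argmax(outputs_val, axis=1)
y_pred = np.array([np.argmax(nn.forward_propagate(x)) for x in inputs_val])
print(confusion_matrix(y_true, y_pred))  # rows = true class, columns = predicted class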

Comments:

An 80/20 split is not best in every situation; it depends on your data. It would help to test your assumption a few more times, but shuffle the dataset this time.
Unfortunately, a question like this does not have a clear-cut answer, especially without access to your data.
Coldspeed, I have provided a dataset (edited).
Swailem95, the dataset is shuffled every epoch and (I believe) before the split as well (see scikit-learn.org/stable/modules/generated/…).
@cᴏʟᴅsᴘᴇᴇᴅ See above (not sure whether I edited the comment tags correctly).

Answer 1:

1) Your sample size is very small. You have 13 dimensions and only 178 samples. Since you need to fit the parameters of your network (13 inputs, a hidden layer of 5 units, 3 outputs), you do not have enough data no matter how you split it. Your model is too complex for the amount of data you have, which leads to overfitting. That means the model does not generalize well: in general it will not give you good results, and it will not give you stable results.

2) There will always be some difference between the training set and the test set. In your case, because the sample is so small, whether the statistics of your training and test data happen to agree is largely a matter of chance.

3) When you split 90-10, your test set contains only about 18 samples. You cannot learn much from 18 trials; it can hardly be called "statistics". Try different splits and your results will change as well (you have already seen this happen, as I noted above regarding stability).

4) Always compare your classifier against the performance of a random classifier. With your 3 classes, you should at least score better than about 33% (see the sketch after this list).

5) Read up on cross-validation and leave-one-out validation.
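
For example, a minimal sketch of points 4) and 5) with scikit-learn: a random baseline plus stratified 10-fold and leave-one-out cross-validation. It uses the bundled Wine data (178 samples, 13 features, 3 classes) and a plain logistic regression purely as placeholders for your CSV and your network:

from sklearn.datasets import load_wine
from sklearn.dummy import DummyClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score, StratifiedKFold, LeaveOneOut
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

X, y = load_wine(return_X_y=True)

random_baseline = DummyClassifier(strategy='uniform')  # predicts the 3 classes uniformly at random
model = make_pipeline(StandardScaler(), LogisticRegression(max_iter=1000))

cv = StratifiedKFold(n_splits=10, shuffle=True)
print('random baseline :', cross_val_score(random_baseline, X, y, cv=cv).mean())
print('10-fold CV      :', cross_val_score(model, X, y, cv=cv).mean())
print('leave-one-out   :', cross_val_score(model, X, y, cv=LeaveOneOut()).mean())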

