Why does a 50-50 train/test split work best for a data-set of 178 observations with this neural network?
Posted: 2018-03-30 21:23:56
Question:
From what I've read, a split of roughly 80% training and 20% validation data is close to optimal. As the size of the test set increases, the variance of the validation results should go down, at the cost of less effective training (lower validation accuracy).
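I am therefore confused by the following results, which seem to show the best accuracy and low variance with TEST_SIZE=0.5 (every configuration was run multiple times, and one trial was selected to represent each test size).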
TEST_SIZE=0.1: this should work well because of the large training set, but it has larger variance (accuracy over 5 trials varied between 16% and 50%).
Epoch 0, Loss 0.021541, Targets [ 1. 0. 0.], Outputs [ 0.979 0.011 0.01 ], Inputs [ 0.086 0.052 0.08 0.062 0.101 0.093 0.107 0.058 0.108 0.08 0.084 0.115 0.104]
Epoch 100, Loss 0.001154, Targets [ 0. 0. 1.], Outputs [ 0. 0.001 0.999], Inputs [ 0.083 0.099 0.084 0.079 0.085 0.061 0.02 0.103 0.038 0.083 0.078 0.053 0.067]
Epoch 200, Loss 0.000015, Targets [ 0. 0. 1.], Outputs [ 0. 0. 1.], Inputs [ 0.076 0.092 0.087 0.107 0.077 0.063 0.02 0.13 0.054 0.106 0.054 0.051 0.086]
Target Class 0, Predicted Class 0
Target Class 0, Predicted Class 0
Target Class 1, Predicted Class 0
Target Class 1, Predicted Class 0
Target Class 1, Predicted Class 0
Target Class 0, Predicted Class 0
Target Class 1, Predicted Class 0
Target Class 1, Predicted Class 0
Target Class 1, Predicted Class 0
Target Class 1, Predicted Class 0
Target Class 0, Predicted Class 0
Target Class 0, Predicted Class 0
Target Class 1, Predicted Class 0
Target Class 0, Predicted Class 0
Target Class 2, Predicted Class 2
Target Class 1, Predicted Class 0
Target Class 0, Predicted Class 0
Target Class 2, Predicted Class 2
50.0% overall accuracy for validation set.
TEST_SIZE=0.5: this should work okay (accuracy somewhere between the other two cases), yet for some reason accuracy over 5 trials varied between 92% and 97%.
Epoch 0, Loss 0.547218, Targets [ 1. 0. 0.], Outputs [ 0.579 0.087 0.334], Inputs [ 0.106 0.08 0.142 0.133 0.129 0.115 0.127 0.13 0.12 0.068 0.123 0.126 0.11 ]
Epoch 100, Loss 0.002716, Targets [ 0. 1. 0.], Outputs [ 0.003 0.997 0. ], Inputs [ 0.09 0.059 0.097 0.114 0.088 0.108 0.102 0.144 0.125 0.036 0.186 0.113 0.054]
Epoch 200, Loss 0.002874, Targets [ 0. 1. 0.], Outputs [ 0.003 0.997 0. ], Inputs [ 0.102 0.067 0.088 0.109 0.088 0.097 0.091 0.088 0.092 0.056 0.113 0.141 0.089]
Target Class 1, Predicted Class 1
Target Class 0, Predicted Class 0
Target Class 2, Predicted Class 2
Target Class 1, Predicted Class 1
Target Class 2, Predicted Class 2
Target Class 2, Predicted Class 2
Target Class 2, Predicted Class 2
Target Class 1, Predicted Class 1
Target Class 0, Predicted Class 0
Target Class 0, Predicted Class 0
Target Class 0, Predicted Class 0
Target Class 0, Predicted Class 0
Target Class 0, Predicted Class 0
Target Class 0, Predicted Class 0
Target Class 1, Predicted Class 0
Target Class 2, Predicted Class 2
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 0, Predicted Class 0
Target Class 2, Predicted Class 2
Target Class 1, Predicted Class 1
Target Class 0, Predicted Class 0
Target Class 1, Predicted Class 1
Target Class 2, Predicted Class 2
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 2, Predicted Class 1
Target Class 2, Predicted Class 2
Target Class 0, Predicted Class 0
Target Class 1, Predicted Class 1
Target Class 0, Predicted Class 0
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 0, Predicted Class 0
Target Class 2, Predicted Class 2
Target Class 0, Predicted Class 0
Target Class 0, Predicted Class 0
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 0, Predicted Class 0
Target Class 1, Predicted Class 1
Target Class 0, Predicted Class 0
Target Class 0, Predicted Class 0
Target Class 1, Predicted Class 1
Target Class 2, Predicted Class 2
Target Class 2, Predicted Class 2
Target Class 1, Predicted Class 1
Target Class 0, Predicted Class 0
Target Class 1, Predicted Class 1
Target Class 2, Predicted Class 2
Target Class 1, Predicted Class 1
Target Class 2, Predicted Class 2
Target Class 1, Predicted Class 1
Target Class 0, Predicted Class 0
Target Class 1, Predicted Class 1
Target Class 0, Predicted Class 0
Target Class 2, Predicted Class 2
Target Class 2, Predicted Class 2
Target Class 1, Predicted Class 1
Target Class 0, Predicted Class 0
Target Class 2, Predicted Class 2
Target Class 2, Predicted Class 2
Target Class 0, Predicted Class 0
Target Class 0, Predicted Class 0
Target Class 1, Predicted Class 1
Target Class 0, Predicted Class 0
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 0, Predicted Class 0
Target Class 1, Predicted Class 1
Target Class 0, Predicted Class 0
Target Class 2, Predicted Class 2
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 0, Predicted Class 0
Target Class 1, Predicted Class 1
Target Class 0, Predicted Class 0
Target Class 2, Predicted Class 2
Target Class 2, Predicted Class 2
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 2, Predicted Class 2
Target Class 2, Predicted Class 2
Target Class 1, Predicted Class 1
Target Class 0, Predicted Class 0
Target Class 1, Predicted Class 1
97.75280898876404% overall accuracy for validation set.
TEST_SIZE=0.9: this should generalize poorly because of the small training sample; accuracy over 5 trials varied between 38% and 54%.
Epoch 0, Loss 2.448474, Targets [ 0. 0. 1.], Outputs [ 0.707 0.206 0.086], Inputs [ 0.229 0.421 0.266 0.267 0.223 0.15 0.057 0.33 0.134 0.148 0.191 0.12 0.24 ]
Epoch 100, Loss 0.017506, Targets [ 1. 0. 0.], Outputs [ 0.983 0.017 0. ], Inputs [ 0.252 0.162 0.274 0.255 0.241 0.275 0.314 0.175 0.278 0.135 0.286 0.36 0.281]
Epoch 200, Loss 0.001819, Targets [ 0. 0. 1.], Outputs [ 0.002 0. 0.998], Inputs [ 0.245 0.348 0.248 0.274 0.284 0.153 0.167 0.212 0.191 0.362 0.145 0.125 0.183]
Target Class 2, Predicted Class 2
Target Class 2, Predicted Class 2
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 0, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 2
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 2, Predicted Class 2
Target Class 0, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 2, Predicted Class 2
Target Class 0, Predicted Class 1
Target Class 2, Predicted Class 2
Target Class 1, Predicted Class 1
Target Class 2, Predicted Class 2
Target Class 2, Predicted Class 2
Target Class 0, Predicted Class 1
Target Class 2, Predicted Class 2
Target Class 2, Predicted Class 2
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 2, Predicted Class 2
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 2, Predicted Class 2
Target Class 0, Predicted Class 1
Target Class 2, Predicted Class 2
Target Class 2, Predicted Class 2
Target Class 1, Predicted Class 1
Target Class 0, Predicted Class 1
Target Class 0, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 0, Predicted Class 1
Target Class 0, Predicted Class 1
Target Class 0, Predicted Class 1
Target Class 0, Predicted Class 1
Target Class 2, Predicted Class 2
Target Class 0, Predicted Class 1
Target Class 2, Predicted Class 2
Target Class 0, Predicted Class 1
Target Class 0, Predicted Class 1
Target Class 0, Predicted Class 1
Target Class 1, Predicted Class 2
Target Class 2, Predicted Class 2
Target Class 2, Predicted Class 2
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 2, Predicted Class 2
Target Class 0, Predicted Class 1
Target Class 2, Predicted Class 2
Target Class 2, Predicted Class 2
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 0, Predicted Class 1
Target Class 0, Predicted Class 1
Target Class 0, Predicted Class 1
Target Class 2, Predicted Class 2
Target Class 0, Predicted Class 1
Target Class 0, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 2, Predicted Class 2
Target Class 0, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 0, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 0, Predicted Class 1
Target Class 0, Predicted Class 1
Target Class 2, Predicted Class 2
Target Class 1, Predicted Class 1
Target Class 2, Predicted Class 2
Target Class 0, Predicted Class 1
Target Class 0, Predicted Class 1
Target Class 2, Predicted Class 2
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 0, Predicted Class 1
Target Class 2, Predicted Class 2
Target Class 1, Predicted Class 1
Target Class 0, Predicted Class 1
Target Class 0, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 2, Predicted Class 2
Target Class 0, Predicted Class 1
Target Class 0, Predicted Class 1
Target Class 0, Predicted Class 1
Target Class 2, Predicted Class 2
Target Class 2, Predicted Class 2
Target Class 0, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 0, Predicted Class 1
Target Class 0, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 2, Predicted Class 2
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 2, Predicted Class 2
Target Class 0, Predicted Class 1
Target Class 0, Predicted Class 1
Target Class 2, Predicted Class 2
Target Class 0, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 0, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 2, Predicted Class 2
Target Class 1, Predicted Class 1
Target Class 0, Predicted Class 1
Target Class 2, Predicted Class 2
Target Class 0, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 0, Predicted Class 1
Target Class 2, Predicted Class 2
Target Class 0, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 0, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 2, Predicted Class 2
Target Class 0, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 2
Target Class 1, Predicted Class 1
Target Class 0, Predicted Class 1
Target Class 0, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 0, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 2, Predicted Class 2
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 0, Predicted Class 1
Target Class 0, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 2, Predicted Class 2
Target Class 2, Predicted Class 2
Target Class 2, Predicted Class 2
Target Class 0, Predicted Class 1
Target Class 0, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 0, Predicted Class 1
Target Class 1, Predicted Class 1
Target Class 2, Predicted Class 2
64.59627329192547% overall accuracy for validation set.
The key function snippets are as follows:
Importing and splitting the dataset
import numpy as np
from sklearn.preprocessing import normalize
from sklearn.model_selection import train_test_split
def readInput(filename, delimiter, inputlen, outputlen, categories, test_size):
    def onehot(num, categories):
        # Class labels in the file are 1-based; convert to a one-hot vector.
        arr = np.zeros(categories)
        arr[int(num[0])-1] = 1
        return arr

    with open(filename) as file:
        inputs = list()
        outputs = list()
        for line in file:
            assert(len(line.split(delimiter)) == inputlen+outputlen)
            # The first outputlen columns hold the label, the rest hold the features.
            outputs.append(onehot(list(map(lambda x: float(x), line.split(delimiter)))[:outputlen], categories))
            inputs.append(list(map(lambda x: float(x), line.split(delimiter)))[outputlen:outputlen+inputlen])
    inputs = np.array(inputs)
    outputs = np.array(outputs)
    inputs_train, inputs_val, outputs_train, outputs_val = train_test_split(inputs, outputs, test_size=test_size)
    assert len(inputs_train) > 0
    assert len(inputs_val) > 0
    # The training and validation inputs are normalized independently, column by column.
    return normalize(inputs_train, axis=0), outputs_train, normalize(inputs_val, axis=0), outputs_val
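A side note on the last line of `readInput`, offered as a possible contributing factor rather than a confirmed diagnosis: with its default L2 norm, `normalize(..., axis=0)` rescales each feature column to unit length, so the size of the individual values depends on how many rows the array has. Because the training and validation sets are normalized separately, their feature scales match only when both sets have the same number of rows, which is the TEST_SIZE=0.5 case. The printed `Inputs` values in the logs (about 0.08 with 160 training rows, about 0.1 with 89, about 0.24 with 17) are consistent with this. The sketch below reproduces the effect with NumPy only; `l2_normalize_columns`, `mean_scale_ratio`, and the random `features` matrix are illustrative stand-ins, not the asker's code or data.

```python
import numpy as np

def l2_normalize_columns(x):
    """Scale every column to unit L2 norm, mirroring normalize(x, axis=0)."""
    norms = np.linalg.norm(x, axis=0)
    norms[norms == 0] = 1.0  # leave all-zero columns unchanged
    return x / norms

rng = np.random.default_rng(0)
# Hypothetical stand-in for the 178 x 13 feature matrix (not the real data).
features = rng.uniform(1.0, 2.0, size=(178, 13))

def mean_scale_ratio(test_size):
    """Ratio of the mean validation input to the mean training input
    when each subset is normalized on its own."""
    n_val = int(np.ceil(test_size * len(features)))
    val, train = features[:n_val], features[n_val:]
    return l2_normalize_columns(val).mean() / l2_normalize_columns(train).mean()

for test_size in (0.1, 0.5, 0.9):
    # The ratio comes out close to sqrt(n_train / n_val).
    print(test_size, round(mean_scale_ratio(test_size), 2))
```

If this is indeed what is happening, a common remedy is to compute the scaling statistics on the training set only and apply the same transformation to the validation set, for example by fitting scikit-learn's `StandardScaler` on `inputs_train` and calling `transform` on both arrays.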
Some parameters
import numpy as np
import helper
FILE_NAME = 'data2.csv'
DATA_DELIM = ','
ACTIVATION_FUNC = 'tanh'
TESTING_FREQ = 100
EPOCHS = 200
LEARNING_RATE = 0.2
TEST_SIZE = 0.9
INPUT_SIZE = 13
HIDDEN_LAYERS = [5]
OUTPUT_SIZE = 3
Main program flow
# Methods of the Classifier class (the rest of the class is not shown).
def step(self, x, targets, lrate):
    self.forward_propagate(x)
    self.backpropagate_errors(targets)
    self.adjust_weights(x, lrate)

def test(self, epoch, x, target):
    predictions = self.forward_propagate(x)
    print('Epoch %5i, Loss %2f, Targets %s, Outputs %s, Inputs %s' % (epoch, helper.crossentropy(target, predictions), target, predictions, x))

def train(self, inputs, targets, epochs, testfreq, lrate):
    xindices = [i for i in range(len(inputs))]
    for epoch in range(epochs):
        # Visit the training samples in a new random order every epoch.
        np.random.shuffle(xindices)
        if epoch % testfreq == 0:
            self.test(epoch, inputs[xindices[0]], targets[xindices[0]])
        for i in xindices:
            self.step(inputs[i], targets[i], lrate)
    self.test(epochs, inputs[xindices[0]], targets[xindices[0]])

def validate(self, inputs, targets):
    correct = 0
    targets = np.argmax(targets, axis=1)
    for i in range(len(inputs)):
        prediction = np.argmax(self.forward_propagate(inputs[i]))
        if prediction == targets[i]: correct += 1
        print('Target Class %s, Predicted Class %s' % (targets[i], prediction))
    print('%s%% overall accuracy for validation set.' % (correct/len(inputs)*100))

# Script: load and split the data, then train and validate the network.
np.random.seed()
inputs_train, outputs_train, inputs_val, outputs_val = helper.readInput(FILE_NAME, DATA_DELIM, inputlen=INPUT_SIZE, outputlen=1, categories=OUTPUT_SIZE, test_size=TEST_SIZE)
nn = Classifier([INPUT_SIZE] + HIDDEN_LAYERS + [OUTPUT_SIZE], ACTIVATION_FUNC)
nn.train(inputs_train, outputs_train, EPOCHS, TESTING_FREQ, LEARNING_RATE)
nn.validate(inputs_val, outputs_val)
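Comments:
An 80/20 split is not optimal in every case. It depends on your data. Testing your hypothesis a few more times would help, but this time with the dataset shuffled.
Unfortunately, a question like this has no definitive answer, especially without access to your data.
Coldspeed, I have provided the dataset (edited). Swailem95, the dataset is shuffled at every epoch and (I believe) before splitting (see scikit-learn.org/stable/modules/generated/…)
@cᴏʟᴅsᴘᴇᴇᴅ see above (not sure whether I edited the comment tag correctly).
Answer 1:
1) The sample size is very small. You have 13 dimensions and only 178 samples. Since you need to fit all the weights of your NN (HIDDEN_LAYERS = [5], i.e. one hidden layer of 5 units), there is not enough data however you split it. Your model is too complex for the amount of data you have, which leads to overfitting. This means the model does not generalize well: in general it will give you neither good nor stable results.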
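2) There will always be some difference between the training and test datasets. In your case, because the sample is so small, how well the statistics of your test and training data agree is mostly down to chance.
3) With a 90-10 split, your test set has only about 18 samples. You cannot learn much from so few trials; it can hardly be called "statistics". Try a different split and your results will change as well (you have already seen this, as mentioned above regarding robustness).
4) Always compare your classifier with the performance of a random classifier. With your 3 classes, you should score at least above 33%.
5) Read up on cross-validation and leave-one-out.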
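Points 3 to 5 can be sketched together. The snippet below is a minimal k-fold cross-validation loop in plain NumPy, evaluated on a majority-class baseline so there is a chance-level reference to compare against. It is an illustration under stated assumptions, not the asker's pipeline: `kfold_indices`, `cross_validate`, and `majority_baseline` are hypothetical helper names, and the data is random noise with the same shape as the question's (178 rows, 13 features, 3 classes). To evaluate the real network, `fit_predict` would wrap training a fresh `Classifier` on the training fold and predicting on the held-out fold.

```python
import numpy as np

def kfold_indices(n_samples, n_folds, rng):
    """Yield (train_idx, val_idx) pairs; every sample is held out exactly once."""
    order = rng.permutation(n_samples)
    for fold in np.array_split(order, n_folds):
        yield np.setdiff1d(order, fold), fold

def cross_validate(fit_predict, x, y, n_folds, seed=0):
    """Return per-fold accuracies; n_folds == len(x) gives leave-one-out."""
    rng = np.random.default_rng(seed)
    scores = []
    for train, val in kfold_indices(len(x), n_folds, rng):
        predictions = fit_predict(x[train], y[train], x[val])
        scores.append(np.mean(predictions == y[val]))
    return np.array(scores)

def majority_baseline(x_train, y_train, x_val):
    """Chance-level reference: always predict the most common training class."""
    most_common = np.bincount(y_train).argmax()
    return np.full(len(x_val), most_common)

# Hypothetical data: 178 rows, 13 features, 3 integer class labels.
rng = np.random.default_rng(1)
x = rng.normal(size=(178, 13))
y = rng.integers(0, 3, size=178)

scores = cross_validate(majority_baseline, x, y, n_folds=10)
print("baseline accuracy: %.3f +/- %.3f" % (scores.mean(), scores.std()))
```

Reporting the mean and spread over folds uses every sample for validation once, which is more informative than a single small hold-out set. As a rough guide to point 3, the standard error of an accuracy p measured on n samples is about sqrt(p(1-p)/n), which is roughly 0.12 for n = 18 and p = 0.5.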