为啥神经网络的准确性不好?
Posted
技术标签:
【中文标题】为啥神经网络的准确性不好?【英文标题】:Why bad accuracy with neural network?为什么神经网络的准确性不好? 【发布时间】:2015-09-03 09:31:52 【问题描述】:我有一个包含 43 个示例(数据点)和 70'000 个特征的数据集,这意味着我的数据集矩阵是 (43 x 70'000)。标签包含 4 个不同的值 (1-4),即有 4 个类。
现在,我已经使用深度信念网络/神经网络进行了分类,但通过留一法交叉验证,我得到的准确率只有 25% 左右(机会水平)。如果我使用 kNN、SVM 等,我的准确率将超过 80%。
我已经使用了用于 Matlab 的 DeepLearnToolbox (https://github.com/rasmusbergpalm/DeepLearnToolbox),并且刚刚从工具箱的自述文件中改编了深度信念网络示例。我尝试了不同数量的隐藏层(1-3)和不同数量的隐藏节点(100、500,...)以及不同的学习率、动量等,但准确性仍然很差。特征向量被缩放到 [0,1] 范围,因为这是工具箱所需要的。
详细我做了以下代码(只显示了一次交叉验证):
% Indices of training and test set
train = training(c,m);
test = ~train;
% Train DBN
opts = [];
dbn = [];
dbn.sizes = [500 500 500];
opts.numepochs = 50;
opts.batchsize = 1;
opts.momentum = 0.001;
opts.alpha = 0.15;
dbn = dbnsetup(dbn, feature_vectors_std(train,:), opts);
dbn = dbntrain(dbn, feature_vectors_std(train,:), opts);
%unfold dbn to nn
nn = dbnunfoldtonn(dbn, 4);
nn.activation_function = 'sigm';
nn.learningRate = 0.15;
nn.momentum = 0.001;
%train nn
opts.numepochs = 50;
opts.batchsize = 1;
train_labels = labels(train);
nClass = length(unique(train_labels));
L = zeros(length(train_labels),nClass);
for i = 1:nClass
L(train_labels == i,i) = 1;
end
nn = nntrain(nn, feature_vectors_std(train,:), L, opts);
class = nnpredict(nn, feature_vectors_std(test,:));
feature_vectors_std 是 (43 x 70'000) 矩阵,其值缩放为 [0,1]。
有人能推断出为什么我的准确率这么差吗?
【问题讨论】:
【参考方案1】:因为您的特征比数据集中的示例多得多。换句话说:你有大量的权重,你需要估计所有的权重,但你不能,因为结构如此庞大的神经网络无法在如此小的数据集上很好地泛化,你需要更多的数据来学习如此大量的隐藏权重(实际上 NN 可能会记住你的训练集,但不能推断出它是测试集的“知识”)。同时,SVM 和 kNN 等简单方法的 80% 准确率表明您可以使用更简单的规则来描述您的数据,因为例如 SVM 将只有 70k 权重(而不是 70kfirst_layer_size + first_layer_size second_layer_size + ... in NN),kNN 根本不会使用权重。
复杂模型不是灵丹妙药,您尝试拟合的模型越复杂 - 您需要的数据就越多。
【讨论】:
【参考方案2】:很明显,您的数据集比网络的复杂性太小。 reference from there
一个神经网络的复杂度可以用数字来表示 的参数。在深度神经网络的情况下,这个数字可以是 在数百万、数千万范围内,在某些情况下甚至 亿万。让我们称这个号码为 P。既然你想成为 确定模型的泛化能力,这是一个很好的经验法则 数据点的数量至少为 P*P。
虽然 KNN 和 SVM 更简单,但它们不需要那么多数据。 这样他们才能更好地工作。
【讨论】:
以上是关于为啥神经网络的准确性不好?的主要内容,如果未能解决你的问题,请参考以下文章
为啥在 Keras 中使用前馈神经网络进行单独的训练、验证和测试数据集可以获得 100% 的准确率?