Tensorflow MNIST 教程 - 测试精度非常低

Posted

技术标签:

【中文标题】Tensorflow MNIST 教程 - 测试精度非常低【英文标题】:Tensorflow MNIST tutorial - Test Accuracy very low 【发布时间】:2017-09-20 15:49:36 【问题描述】:

我从 tensorflow 开始,一直遵循这个标准MNIST tutorial。

但是,与预期的 92% 准确度相比,在训练集和测试集上获得的准确度并未超过 67%。 我熟悉 softmax 和多项式回归,并且使用从头开始的 python 实现以及使用sklearn.linear_model.LogisticRegression 获得了超过 94% 的结果。

我曾尝试过使用 CIFAR-10 数据集,但在这种情况下,准确度太低,只有 10% 左右,这等于随机分配类。这让我怀疑我是否安装了 tensorflow,但我对此不确定。

这里是my implementation of Tensorflow MNIST tutorial。我会请求是否有人可以看看我的实现。

【问题讨论】:

【参考方案1】:

您构建了图表,指定了损失函数,并创建了优化器(这是正确的)。问题是您只使用了一次优化器:

sess_tf.run(train_step, feed_dict=x: train_images_reshaped[0:1000], y_: train_labels[0:1000])

所以基本上你只运行一次梯度下降。很明显,你不可能在朝着正确方向迈出一小步后就快速收敛。你需要做一些事情:

for _ in xrange(many_steps):
  X, Y = get_a_new_batch_from(mnist_data)
  sess_tf.run(train_step, feed_dict=x: X, y_: Y)

如果您无法弄清楚如何修改我的伪代码,请查阅教程,因为根据我的记忆,他们很好地涵盖了这一点。

【讨论】:

感谢您指出。作为教程中的批量梯度下降,不知何故,我已经确定我不需要迭代并完全跳过了收敛概念。已按预期进行更改。【参考方案2】:
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))

W 的初始化可能会导致您的网络除了随机猜测之外什么都学不到。因为 grad 将为零,而反向传播实际上根本不起作用。

您最好使用tf.Variable(tf.truncated_normal([784, 10], mean=0.0, stddev=0.01)) 初始化W,更多信息请参见https://www.tensorflow.org/api_docs/python/tf/truncated_normal。

【讨论】:

我尝试过这种方法,但效果不佳,@Salvador Dali 已确定原因。【参考方案3】:

不确定这在 2018 年 6 月是否仍然适用,但 MNIST beginner tutorial 不再与 example code on Github 匹配。如果您下载并运行示例代码,它确实会为您提供建议的 92% 准确度。

在学习本教程时,我发现有两个问题:

1) 不小心调用了两次softmax

教程首先告诉你定义y如下:

y = tf.nn.softmax(tf.matmul(x, W) + b)

但后来建议您使用tf.nn.softmax_cross_entropy_with_logits 定义交叉熵,这样很容易意外地执行以下操作:

cross_entropy = tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y)

这会将您的 logits (tf.matmul(x, W) + b) 通过 softmax 发送两次,这导致我的准确率停留在 67%。

但是我注意到,即使修复了这个问题,我仍然只能达到非常不稳定的 80-90% 准确度,这导致我进入下一个问题:

2) tf.nn.softmax_cross_entropy_with_logits() 已弃用

他们还没有更新教程,但是tf.nn.softmax_cross_entropy_with_logits page 表示该功能已被弃用。

在 Github 上的示例代码中,他们已将其替换为 tf.losses.sparse_softmax_cross_entropy(labels=y_, logits=y)

但是,您不能只是将函数换出 - 示例代码还会更改许多其他行的维度。

我对第一次这样做的人的建议是从 Github 下载当前工作的示例代码,并尝试将其与教程概念相匹配,而不是按字面意思理解说明。希望他们能及时更新!

【讨论】:

以上是关于Tensorflow MNIST 教程 - 测试精度非常低的主要内容,如果未能解决你的问题,请参考以下文章

python MNIST使用批量标准化 - TensorFlow教程

tensorflow的MNIST教程

Tensorflow学习教程------普通神经网络对mnist数据集分类

教程 | 使用MNIST数据集,在TensorFlow上实现基础LSTM网络

Tensorflow的MNIST进阶教程CNN网络参数理解

TensorFlow力学101笔记[4]