Tensorflow MNIST 教程 - 测试精度非常低
Posted
技术标签:
【中文标题】Tensorflow MNIST 教程 - 测试精度非常低【英文标题】:Tensorflow MNIST tutorial - Test Accuracy very low 【发布时间】:2017-09-20 15:49:36 【问题描述】:我从 tensorflow 开始,一直遵循这个标准MNIST tutorial。
但是,与预期的 92% 准确度相比,在训练集和测试集上获得的准确度并未超过 67%。 我熟悉 softmax 和多项式回归,并且使用从头开始的 python 实现以及使用sklearn.linear_model.LogisticRegression 获得了超过 94% 的结果。
我曾尝试过使用 CIFAR-10 数据集,但在这种情况下,准确度太低,只有 10% 左右,这等于随机分配类。这让我怀疑我是否安装了 tensorflow,但我对此不确定。
这里是my implementation of Tensorflow MNIST tutorial。我会请求是否有人可以看看我的实现。
【问题讨论】:
【参考方案1】:您构建了图表,指定了损失函数,并创建了优化器(这是正确的)。问题是您只使用了一次优化器:
sess_tf.run(train_step, feed_dict=x: train_images_reshaped[0:1000], y_: train_labels[0:1000])
所以基本上你只运行一次梯度下降。很明显,你不可能在朝着正确方向迈出一小步后就快速收敛。你需要做一些事情:
for _ in xrange(many_steps):
X, Y = get_a_new_batch_from(mnist_data)
sess_tf.run(train_step, feed_dict=x: X, y_: Y)
如果您无法弄清楚如何修改我的伪代码,请查阅教程,因为根据我的记忆,他们很好地涵盖了这一点。
【讨论】:
感谢您指出。作为教程中的批量梯度下降,不知何故,我已经确定我不需要迭代并完全跳过了收敛概念。已按预期进行更改。【参考方案2】:W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
W
的初始化可能会导致您的网络除了随机猜测之外什么都学不到。因为 grad 将为零,而反向传播实际上根本不起作用。
您最好使用tf.Variable(tf.truncated_normal([784, 10], mean=0.0, stddev=0.01))
初始化W
,更多信息请参见https://www.tensorflow.org/api_docs/python/tf/truncated_normal。
【讨论】:
我尝试过这种方法,但效果不佳,@Salvador Dali 已确定原因。【参考方案3】:不确定这在 2018 年 6 月是否仍然适用,但 MNIST beginner tutorial 不再与 example code on Github 匹配。如果您下载并运行示例代码,它确实会为您提供建议的 92% 准确度。
在学习本教程时,我发现有两个问题:
1) 不小心调用了两次softmax
教程首先告诉你定义y如下:
y = tf.nn.softmax(tf.matmul(x, W) + b)
但后来建议您使用tf.nn.softmax_cross_entropy_with_logits
定义交叉熵,这样很容易意外地执行以下操作:
cross_entropy = tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y)
这会将您的 logits (tf.matmul(x, W) + b
) 通过 softmax 发送两次,这导致我的准确率停留在 67%。
但是我注意到,即使修复了这个问题,我仍然只能达到非常不稳定的 80-90% 准确度,这导致我进入下一个问题:
2) tf.nn.softmax_cross_entropy_with_logits() 已弃用
他们还没有更新教程,但是tf.nn.softmax_cross_entropy_with_logits page 表示该功能已被弃用。
在 Github 上的示例代码中,他们已将其替换为 tf.losses.sparse_softmax_cross_entropy(labels=y_, logits=y)
。
但是,您不能只是将函数换出 - 示例代码还会更改许多其他行的维度。
我对第一次这样做的人的建议是从 Github 下载当前工作的示例代码,并尝试将其与教程概念相匹配,而不是按字面意思理解说明。希望他们能及时更新!
【讨论】:
以上是关于Tensorflow MNIST 教程 - 测试精度非常低的主要内容,如果未能解决你的问题,请参考以下文章
python MNIST使用批量标准化 - TensorFlow教程
Tensorflow学习教程------普通神经网络对mnist数据集分类