Tensorflow之MNIST手写数字识别:分类问题

Posted lsm-boke

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Tensorflow之MNIST手写数字识别:分类问题相关的知识,希望对你有一定的参考价值。

整体代码:

#数据读取
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/",one_hot=True)

#定义待输入数据的占位符
#mnist中每张照片共有28*28=784个像素点
x = tf.placeholder(tf.float32,[None,784],name="X")

#0-9一共10个数字=>10个类别
y = tf.placeholder(tf.float32,[None,10],name="Y")

#定义模型变量
#以正态分布的随机数初始化权重W,以常数0初始化偏置b
#在神经网络中,权值W的初始值通常设为正态分布的随机数,偏置项b的初始值通常也设置为正态分布的随机数或常数。
W = tf.Variable(tf.random_normal([784,10],name="W"))
b = tf.Variable(tf.zeros([10]),name="b")

#用单个神经元构建神经网络
forward=tf.matmul(x,W) + b   #前向计算

#结果分类
#当我们处理多分类任务的时候,通常需要使用Softmax Regression模型。Softmax会对每一类别估算出一个概率。
#工作原理:将判定为某一类的特征相加,然后将这些特征转化为判定是这一类的概率
pred = tf.nn.softmax(forward)     #Softmax分类

#设置训练参数
train_epochs = 120     #训练轮数
batch_size = 120      #单次训练样本数(批次大小)
total_batch = int(mnist.train.num_examples/batch_size)              #一轮训练有多少批次
display_step = 1   #显示粒度
learning_rate = 0.01             #学习率 

#概率估算值需要将预测输出值控制在[0,1]区间内。二元分类问题的目标是正确预测两个可能标签中的一个
#逻辑回归可以用于处理这类问题。二元逻辑回归的损失函数一般采用对数损失函数
#多元分类:逻辑回归可生成介于0到1.0之间的小数。Softmax将这一想法延伸到多类别领域。
#在多类别问题中,Softmax会为每个类别分配一个用小数表示的概率。这些用小数表示的概率相加之和必须是1.0

#交叉熵损失函数:交叉熵是一个信息论的概念,它原来是用来估算平均编码长度的。
#交叉熵刻画的是两个概率分布之间的距离,p代表正确答案,q代表的预测值,交叉熵越小,两个概率的分布越接近
#定义损失函数
loss_function = tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred),reduction_indices=1))    #交叉熵

#选择优化器
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss_function)     #梯度下降优化器

#定义准确率
# 检查预测类别tf.argmax(pred,1)与实际类别tf.argmax(y,1)的匹配情况
#argmax()将数组中最大值的下标取出来
correct_prediction = tf.equal(tf.argmax(pred,1),tf.argmax(y,1))

#准确率,将布尔值转化为浮点数,并计算平均值    tf.cast()将布尔值投射成浮点数
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))

#声明会话,初始化变量
sess = tf.Session()
init = tf.global_variables_initializer()   #变量初始化
sess.run(init)

#训练模型
for epoch in range(train_epochs):
    for batch in range(total_batch):
        xs,ys = mnist.train.next_batch(batch_size)  #读取批次数据
        sess.run(optimizer,feed_dict={x:xs,y:ys})   #执行批次训练
        
    #total_batch个批次训练完成后,使用验证数据计算误差与准确率,验证集没有分批
    loss,acc = sess.run([loss_function,accuracy],feed_dict={x:mnist.validation.images,y:mnist.validation.labels})
    
    #打印训练过程中的详细信息
    if (epoch+1) % display_step == 0:
        print("Train Epoch:",%02d%(epoch+1),"Loss=","{:.9f}".format(loss),"Accuracy=","{:.4f}".format(acc))
        
print("Train Finished!")
        
#评估模型
#完成训练后,在测试集上评估模型的准确率
accu_test = sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
print("Test Accuracy:",accu_test)
#完成训练后,在验证集上评估模型的准确率
accu_validation = sess.run(accuracy,feed_dict={x:mnist.validation.images,y:mnist.validation.labels})
print("Test Accuracy:",accu_validation)
#完成训练后,在训练集上评估模型的准确率
accu_train = sess.run(accuracy,feed_dict={x:mnist.train.images,y:mnist.train.labels})
print("Test Accuracy:",accu_train)

#应用模型
#在建立模型并进行训练后,若认为准确率可以接受,则可以使用此模型进行预测
#由于pred预测结果是one_hot编码格式,所以需要转换成0~9数字
prediction_result = sess.run(tf.argmax(pred,1),feed_dict={x:mnist.test.images})

#查看预测结果中的前10项
prediction_result[0:10]

#定义可视化函数
def plot_images_labels_prediction(images,labels,prediction,index,num=10):  #参数: 图形列表,标签列表,预测值列表,从第index个开始显示,缺省一次显示10幅
    fig = plt.gcf()             #获取当前图表,Get Current Figure
    fig.set_size_inches(10,12)    #1英寸等于2.45cm
    if num > 25 :      #最多显示25个子图
        num = 25
    for i in range(0,num):
        ax = plt.subplot(5,5,i+1)   #获取当前要处理的子图
        ax.imshow(np.reshape(images[index],(28,28)), cmap = binary)              #显示第index个图像
        title = "labels="+str(np.argmax(labels[index]))              #构建该图上要显示的title信息
        if len(prediction)>0:
            title += ",predict="+str(prediction[index])
            
        ax.set_title(title,fontsize=10)    #显示图上的title信息
        ax.set_xticks([])           #不显示坐标轴
        ax.set_yticks([])
        index += 1
    plt.show()
#可视化预测结果
plot_images_labels_prediction(mnist.test.images,mnist.test.labels,prediction_result,10,10)

 

以上是关于Tensorflow之MNIST手写数字识别:分类问题的主要内容,如果未能解决你的问题,请参考以下文章

TensorFlow 之 手写数字识别MNIST

TensorFlow 入门之手写识别(MNIST) softmax算法

手写数字识别——基于全连接层和MNIST数据集

Tensorflow项目实战一:MNIST手写数字识别

TensorFlow MNIST 手写数字识别之过拟合

tensorflow实现mnist手写数字识别