tensorflow demo 手写数字识别
Posted _刘文凯_
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了tensorflow demo 手写数字识别相关的知识,希望对你有一定的参考价值。
记录下使用tensorflow进行多分类任务,也就是识别0-9这10个数字
环境
python3.7
tensorflow1.15.0
代码
import tensorflow as tf
import numpy as np
from sklearn import datasets
from tensorflow.examples.tutorials.mnist import input_data
# 读取拟南芥数据
def read_infile():
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
return mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels
# 输入、生成权重w, 偏差b
def w_biases_placeholder(n_dim, n_clasess):
X = tf.placeholder(tf.float32, [None, n_dim])
Y = tf.placeholder(tf.float32, [None,n_clasess])
w = tf.Variable(tf.random_normal([n_dim, n_clasess],stddev=0.01), name= 'w')
b = tf.Variable(tf.random_normal([n_clasess]), name='w')
return X, Y, w, b
def forward_pass(w,b,X):
return tf.matmul(X,w)+b
def multiclass_cost(cout,Y):
# 计算交叉熵损失
return tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=cout, labels=Y))
# 初始化变量
def init():
return tf.global_variables_initializer()
# 定义参数优化方法
def train_op(learning_rate, cost):
return tf.python.train.GradientDescentOptimizer(learning_rate).minimize(cost)
## 开始训练模型 ##
def train_model(learning_rate=0.01, epochs=1000):
trainx, trainy, testx, testy = read_infile()
X, Y, w, b = w_biases_placeholder(trainx.shape[1], trainy.shape[1])
out = forward_pass(w, b, X)
cost = multiclass_cost(out,Y)
op_train = train_op(learning_rate, cost)
init_ = init()
loss_trace = []
acc_trace = []
with tf.Session() as sess:
sess.run(init_)
for i in range(epochs):
sess.run(op_train, feed_dict=X: trainx, Y: trainy)
loss_ = sess.run(cost, feed_dict=X: trainx, Y: trainy)
acc_ = np.mean(np.argmax(sess.run(out, feed_dict=X: trainx, Y: trainy),axis=1) == np.argmax(trainy, axis=1))
loss_trace.append(loss_)
acc_trace.append(acc_)
if (i+1)%100 == 0 and (i+1)//100 >=1:
print('acc:',acc_)
loss_test = sess.run(cost, feed_dict=X: testx, Y: testy)
pred = np.argmax(sess.run(out, feed_dict=X: testx, Y: testy), axis=1)
acc_test = np.mean(pred == np.argmax(testy, axis=1))
print(loss_test) #
print(acc_test) # 得到acc准确率
import matplotlib.pyplot as plt
print('\\n')
print('True:', np.argmax(testy[0:10], axis=1))
print('Pred:', pred[0:10])
f, a = plt.subplots(1,10,figsize=(10, 2))
for i in range(10):
a[i].imshow(np.reshape(testx[i],(28,28)))
if __name__ == '__main__':
train_model()
以上是关于tensorflow demo 手写数字识别的主要内容,如果未能解决你的问题,请参考以下文章
TensorFlow1.x 代码实战系列:MNIST手写数字识别
Tensorflow暑期实践——基于单个神经元的手写数字识别(全部代码)