TensorFlow 的 Hello World
Posted z-bear
tags:
篇首语：本文由小常识网（cha138.com）小编为大家整理，主要介绍了 TensorFlow 的 Hello World 相关的知识，希望对你有一定的参考价值。
import tensorflow as tf;
from tensorflow.examples.tutorials.mnist import input_data
# ---------------------------------------------------------------------------
# Network architecture (one hidden ReLU layer).
# ---------------------------------------------------------------------------
input_nodes = 784    # width of each input feature vector (placeholder train_x)
output_nodes = 10    # number of output units, one per class
layer1_nodes = 500   # width of the hidden layer

# ---------------------------------------------------------------------------
# Hyper-parameters.
# ---------------------------------------------------------------------------
# Exponentially decayed learning rate: starting rate, multiplicative decay
# factor, and number of global steps between decays.
learning_rate_base, learning_decay, decay_step = 0.8, 0.99, 100

# Decay used by the exponential moving average of the trainable variables.
moving_average__decay = 0.99
# Weight of the L2 penalty added for every weight matrix.
regularizer_rate = 0.01
# Number of training iterations, and examples drawn per iteration.
train_step = 30000
batch_size = 100
def inference(tensor1, weight1, bias1, weight2, bias2, average_class=None):
    """Return the output logits of a network with one hidden ReLU layer.

    Computes ``relu(tensor1 @ w1 + b1) @ w2 + b2``.

    Args:
        tensor1: input tensor fed to the first matmul.
        weight1, bias1: hidden-layer weight matrix and bias.
        weight2, bias2: output-layer weight matrix and bias.
        average_class: optional object whose ``average(var)`` method is used
            in place of every variable (``train`` passes a
            ``tf.train.ExponentialMovingAverage`` here to evaluate with the
            shadow values). When ``None``, the variables are used directly.

    Returns:
        The un-normalized output tensor (no softmax applied).
    """
    if average_class is None:
        def read(var):
            return var
    else:
        read = average_class.average

    layer1 = tf.nn.relu(tf.matmul(tensor1, read(weight1)) + read(bias1))
    return tf.matmul(layer1, read(weight2)) + read(bias2)
def get_weight(shape):
    """Create a weight variable and register its L2 penalty.

    The variable is initialized from a truncated normal with stddev 0.1.
    Its L2 regularization term, scaled by the module-level
    ``regularizer_rate``, is appended to the ``'losses'`` collection so the
    training loss can sum all penalties with ``tf.add_n``.

    Args:
        shape: shape of the weight tensor, e.g. ``[in_dim, out_dim]``.

    Returns:
        The new float32 ``tf.Variable``.
    """
    # dtype must be passed by keyword: tf.Variable's second positional
    # parameter is ``trainable``, not ``dtype``.
    weight = tf.Variable(
        tf.truncated_normal(shape=shape, stddev=0.1), dtype=tf.float32)
    tf.add_to_collection(
        'losses', tf.contrib.layers.l2_regularizer(regularizer_rate)(weight))
    return weight
def get_bias(shape):
    """Return a new variable of the given shape, initialized to all zeros."""
    zeros = tf.zeros(shape)
    return tf.Variable(zeros)
def train(mnist):
    """Build the one-hidden-layer classifier graph and train it on ``mnist``.

    Training uses plain gradient descent with an exponentially decayed
    learning rate, cross-entropy loss plus the L2 penalties registered by
    ``get_weight``, and an exponential moving average of all trainable
    variables. Accuracy is reported with the moving-average weights: on the
    validation set every 1000 steps, and on the test set at the end.

    Args:
        mnist: dataset object as returned by
            ``input_data.read_data_sets(..., one_hot=True)``. It must expose
            ``train.next_batch(n)``, ``validation.images``/``labels`` and
            ``test.images``/``labels``.
    """
    # Inputs: feature vectors and one-hot labels.
    train_x = tf.placeholder(tf.float32, shape=[None, input_nodes], name='train_x')
    train_y = tf.placeholder(tf.float32, shape=[None, output_nodes], name='train_y')

    weight1 = get_weight([input_nodes, layer1_nodes])
    bias1 = get_bias([layer1_nodes])
    weight2 = get_weight([layer1_nodes, output_nodes])
    bias2 = get_bias([output_nodes])

    # Learning rate decays by `learning_decay` every `decay_step` steps.
    global_step = tf.Variable(0, trainable=False)
    learning_rate = tf.train.exponential_decay(
        learning_rate_base, global_step, decay_step, learning_decay,
        staircase=True)

    # Loss: the network output is the logits; the sparse op wants integer
    # class ids, so take the argmax of the one-hot labels along axis 1.
    results = inference(train_x, weight1, bias1, weight2, bias2, None)
    ce = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=results, labels=tf.argmax(train_y, 1)))
    loss = ce + tf.add_n(tf.get_collection('losses'))
    optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(
        loss, global_step=global_step)

    # Moving average of all trainable variables (global_step is excluded
    # because it was created with trainable=False).
    ema = tf.train.ExponentialMovingAverage(moving_average__decay, global_step)
    maintain_average_op = ema.apply(tf.trainable_variables())
    # A single op that runs both the gradient step and the EMA update.
    with tf.control_dependencies([optimizer, maintain_average_op]):
        train_op = tf.no_op(name='train')

    # Accuracy is evaluated with the moving-average (shadow) weights.
    average_y = inference(train_x, weight1, bias1, weight2, bias2, ema)
    correct_prediction = tf.equal(tf.argmax(average_y, 1), tf.argmax(train_y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        validate_feed = {train_x: mnist.validation.images,
                         train_y: mnist.validation.labels}
        test_feed = {train_x: mnist.test.images, train_y: mnist.test.labels}

        for i in range(train_step):
            if i % 1000 == 0:
                validate_acc = sess.run(accuracy, feed_dict=validate_feed)
                print('After %d training steps, validation accuracy using '
                      'average model is %g' % (i, validate_acc))
            xt, yt = mnist.train.next_batch(batch_size)
            sess.run(train_op, feed_dict={train_x: xt, train_y: yt})

        test_acc = sess.run(accuracy, feed_dict=test_feed)
        print('accuracy is %g' % test_acc)
def main():
    """Load the MNIST data in ./MNIST_data with one-hot labels and train on it."""
    mnist = input_data.read_data_sets('./MNIST_data', one_hot=True)
    train(mnist)
# Run training only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
以上是关于 TensorFlow 的 Hello World 的主要内容，如果未能解决你的问题，请参考以下文章：
[Tensorflow系列-3]:Tensorflow基础 - Hello World程序与张量(Tensor)概述