基于tensorflow的简单线性回归模型
Posted Sugars_DJ
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了基于tensorflow的简单线性回归模型相关的知识,希望对你有一定的参考价值。
#!/usr/local/bin/python3
##ljj [1]
##linear regression model
import tensorflow as tf
import matplotlib.pyplot as plt
#训练样本,随手写的
x_ = [11,14,22,29,32,40,44,55,59,60,69,77]
y_res = [123,135,155,167,177,189,200,240,250,255,277,298]
#初始化定义w和b,都为1,这里折腾了一会,主要因为tf.ones的参数
w = tf.Variable(tf.ones([1]),dtype="float32")
b = tf.Variable(tf.ones([1]),dtype="float32")
y = tf.placeholder(tf.float32)
x = tf.placeholder(tf.float32)
with tf.Session() as sess:
#定义线性模型
y_predict = w*x+b
#平方误差作为损失函数
loss = tf.reduce_mean(tf.square(y-y_predict))
#配置训练优化器和学习速率
train = tf.train.AdamOptimizer(0.03).minimize(loss)
sess.run(tf.global_variables_initializer())
for j in range(1000):
for i in range(len(x_)):
# train.run(feed_dict={x:x_[i], y:y_res[i]})
#feed训练,并输出w和b
w_,b_,_= sess.run([w,b,train],feed_dict={x:x_[i], y:y_res[i]})
print(w_,b_)
print(\'final result : \')
print(w_,b_)
plt.plot(x_,y_res,\'.\')
plt.plot(x_,x_*w_+b_,\'-\')
plt.show()
主机环境:MacbookPro,tensoflow版本1.4,pyhton3.5
输出结果:
final result :
[ 2.65540743] [ 91.92604065]
-------以上输出分别是拟合出的Weight,Bias值。不同版本的tensorlfow,拟合的线可能会略有差异,稍微调调参就可以拟合的不错。
以上是关于基于tensorflow的简单线性回归模型的主要内容,如果未能解决你的问题,请参考以下文章