tensorflow训练线性回归模型

Posted maskerk

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了tensorflow训练线性回归模型相关的知识,希望对你有一定的参考价值。

完整代码

import tensorflow as tf 
import matplotlib.pyplot as plt
import numpy as np

#样本数据
x_train = np.linspace(-1,1,300)[:,np.newaxis]
noise = np.random.normal(0, 0.1, x_train.shape)
y_train = x_train * 3 + noise + 0.8

#线性模型
W = tf.Variable([0.1],dtype = tf.float32)
b = tf.Variable([0.1],dtype = tf.float32)
x = tf.placeholder(tf.float32)
line_model = W * x + b

#损失模型
y = tf.placeholder(tf.float32)
loss = tf.reduce_sum(tf.square(line_model - y))

#创建优化器
optimizer = tf.train.GradientDescentOptimizer(0.001)
train = optimizer.minimize(loss)

#初始化变量
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)

# 绘制样本数据
fig = plt.figure()
ax = fig.add_subplot(1,1,1)
ax.scatter(x_train, y_train)
plt.ion()
plt.show()
plt.pause(3)


#训练100次
for i in range(100):
    #每隔10次打印1次成果
    if i % 10 == 0:
        print(i)
        print('W:%s  b:%s' % (sess.run(W),sess.run(b)))
        print('loss:%s' % (sess.run(loss,{x:x_train,y:y_train})))
    sess.run(train,{x:x_train,y:y_train})

print('---')
print('W:%s  b:%s' % (sess.run(W),sess.run(b)))
print('loss:%s' % (sess.run(loss,{x:x_train,y:y_train})))

样本训练数据分布如下

技术分享图片

输出结果如下

技术分享图片

结论

通过打印结果可以看到W已经非常接近初始给定的3,b也非常接近给定的0.8 (误差不可避免)

以上是关于tensorflow训练线性回归模型的主要内容,如果未能解决你的问题,请参考以下文章

线性回归详解(代码实现+理论证明)

如何使用 tensorflow 训练一个简单的非线性回归模型?

TensorFlow训练Logistic回归

如何使用 tensorflow 或 keras 重新训练具有新子集的线性回归模型?

单变量线性回归:TensorFlow 实战(理论篇)

Tensorflow之单变量线性回归问题的解决方法