tensorflow线性回归预测鲍鱼数据
Posted ywjfx
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了tensorflow线性回归预测鲍鱼数据相关的知识,希望对你有一定的参考价值。
代码如下:
import tensorflow as tf import csv import numpy as np import matplotlib.pyplot as plt # 设置学习率 learning_rate = 0.01 # 设置训练次数 train_steps = 1000 #数据地址:http://archive.ics.uci.edu/ml/datasets/Abalone with open(‘./data/abalone.data‘) as file: reader = csv.reader(file) a, b = [], [] for item in reader: b.append(item[8]) del(item[8]) a.append(item) file.close() x_data = np.array(a) new_x_data = [] for i in x_data[:,0]: if i == ‘M‘: i = 1 elif i == ‘F‘: i = 2 elif i == ‘I‘: i = 3 new_x_data.append(i) new_data = np.array(new_x_data) x_data = np.delete(x_data,0,axis=1) print(x_data.shape) print(new_data.shape) x_data = np.column_stack((new_data,x_data)) #添加一列,将new_data添加到x_data中 print(x_data) print(x_data[:,0]) y_data = np.array(b) for i in range(len(x_data)): y_data[i] = float(y_data[i]) for j in range(len(x_data[i])): x_data[i][j] = float(x_data[i][j]) # 定义各影响因子的权重 weights = tf.Variable(np.ones([8,1]),dtype = tf.float32) x_data_ = tf.placeholder(tf.float32, [None, 8]) y_data_ = tf.placeholder(tf.float32, [None, 1]) bias = tf.Variable(1.0, dtype = tf.float32)#定义偏差值 # 构建模型为:y_model = w1X1 + w2X2 + w3X3 + w4X4 + w5X5 + w6X6 + w7X7 + w8X8 + bias y_model = tf.add(tf.matmul(x_data_ , weights), bias) # 定义损失函数 loss = tf.reduce_mean(tf.pow((y_model - y_data_), 2)) #训练目标为损失值最小,学习率为0.01 train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) print("Start training!") lo = [] sample = np.arange(train_steps) for i in range(train_steps): for (x,y) in zip(x_data, y_data): z1 = x.reshape(1,8) z2 = y.reshape(1,1) sess.run(train_op, feed_dict = {x_data_ : z1, y_data_ : z2}) l = sess.run(loss, feed_dict = {x_data_ : z1, y_data_ : z2}) lo.append(l) print(weights.eval(sess)) print(bias.eval(sess)) # 绘制训练损失变化图 plt.plot(sample, lo, marker="*", linewidth=1, linestyle="--", color="red") plt.title("The variation of the loss") plt.xlabel("Sampling Point") plt.ylabel("Loss") plt.grid(True) plt.show()
以上是关于tensorflow线性回归预测鲍鱼数据的主要内容,如果未能解决你的问题,请参考以下文章
《用Python玩转数据》项目—线性回归分析入门之波士顿房价预测