[ML]简单的Normal Equation对数据点进行线性回归
Posted
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了[ML]简单的Normal Equation对数据点进行线性回归相关的知识,希望对你有一定的参考价值。
注明:本文仅作为记录本人的日常学习历程而存在。
Normal Equation和上篇介绍的方法相比,简单许多。具体公式见吴恩达老师的coursera视频
1.
generate_data用来生成实验中所用到的数据,数据总体分布在斜率为10-30之间随机取值,截距为200-5000之间随机取值的直线上
compute函数用来计算出目标直线参数:
import numpy as np import matplotlib.pyplot as plt def compute(X,Y): return (X.T.dot(X))**(-1)*(X.T)*Y def generate_data(data_size): x = np.random.randint(-250,250,size=data_size) y = [] for i in range(data_size): y.append(x[i]*np.random.randint(10,30)+np.random.randint(200,5000)) return (x,y)
2.
进行计算。
data_size = 500 (xx,yy) = generate_data(data_size) #plt.plot(x,y,‘rx‘) x = [[1,xx[i]] for i in range(data_size)] X = np.matrix(x).reshape((data_size,2)) Y = np.matrix(yy).reshape((data_size,1)) theta = compute(X,Y) theta = theta.getA() print(theta)
3.
可视化:
result_x = np.linspace(-250,250,data_size) result_y = theta[1] * result_x + theta[0] plt.plot(result_x,result_y) plt.plot(xx,yy,‘rx‘)
4.最终结果:
以上是关于[ML]简单的Normal Equation对数据点进行线性回归的主要内容,如果未能解决你的问题,请参考以下文章
正规方程(Normal Equation)——对于线性回归问题的一种快速解法