机器学习--线性回归的实践
Posted dinghing
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了机器学习--线性回归的实践相关的知识,希望对你有一定的参考价值。
1.鉴于之前提到的房价的问题,使用线性回归该如何解决呢?
首先我们假设有如下的数据方便计算机进行学习
面积 | 卧室 | 价格 |
---|---|---|
2140 | 3 | 400 |
1600 | 3 | 330 |
2400 | 3 | 369 |
1416 | 2 | 232 |
... | ... | ... |
根据之前的演算过程(使房价与面积和卧室数目线性相关):
hθ(x)=θ0 +θ1x1 +θ2x2
θ为计算时的权重,x1为房间面积,x2为我是数目。
为了降低计算的模糊程度,将hθ(x)变成h(x)来进行计算,这时计算公式为:
n为学习次数。
2. 有了相关数据之后就要开始训练算法了(ex1a_linreg.m)
1 <span style="font-size:24px;">% 2 %This exercise uses a data from the UCI repository: 3 % Bache, K. & Lichman, M. (2013). UCI Machine Learning Repository 4 % http://archive.ics.uci.edu/ml 5 % Irvine, CA: University of California, School of Information and Computer Science. 6 % 7 %Data created by: 8 % Harrison, D. and Rubinfeld, D.L. 9 % \'\'Hedonic prices and the demand for clean air\'\' 10 % J. Environ. Economics & Management, vol.5, 81-102, 1978. 11 % 12 addpath ../common 13 addpath ../common/minFunc_2012/minFunc 14 addpath ../common/minFunc_2012/minFunc/compiled 15 16 % Load housing data from file. 17 data = load(\'housing.data\'); % housing data 506x14 18 data=data\'; % put examples in columns 14x506 一般这里将每一个样本放在每一列 19 20 % Include a row of 1s as an additional intercept feature. 21 data = [ ones(1,size(data,2)); data ]; % 15x506 增加intercept term 22 23 % Shuffle examples. 乱序 目的在于之后能够随机选取training set和test sets 24 data = data(:, randperm(size(data,2))); %randperm(n)用于随机生成1到n的排列 25 26 % Split into train and test sets 27 % The last row of \'data\' is the median home price. 28 train.X = data(1:end-1,1:400); %选择前400个样本来训练,后面的样本来做测试 29 train.y = data(end,1:400); 30 31 test.X = data(1:end-1,401:end); 32 test.y = data(end,401:end); 33 34 m=size(train.X,2); %训练样本数量 35 n=size(train.X,1); %每个样本的变量个数 36 37 % Initialize the coefficient vector theta to random values. 38 theta = rand(n,1); %随机生成初始theta 每个值在(0,1)之间 39 40 % Run the minFunc optimizer with linear_regression.m as the objective. 41 % 42 % TODO: Implement the linear regression objective and gradient computations 43 % in linear_regression.m 44 % 45 tic; %Start a stopwatch timer. 开始计时 46 options = struct(\'MaxIter\', 200); 47 theta = minFunc(@linear_regression, theta, options, train.X, train.y); 48 fprintf(\'Optimization took %f seconds.\\n\', toc); %toc Read the stopwatch timer 49 50 % Run minFunc with linear_regression_vec.m as the objective. 51 % 52 % TODO: Implement linear regression in linear_regression_vec.m 53 % using MATLAB\'s vectorization features to speed up your code. 54 % Compare the running time for your linear_regression.m and 55 % linear_regression_vec.m implementations. 56 % 57 % Uncomment the lines below to run your vectorized code. 58 %Re-initialize parameters 59 %theta = rand(n,1); 60 %tic; 61 %theta = minFunc(@linear_regression_vec, theta, options, train.X, train.y); 62 %fprintf(\'Optimization took %f seconds.\\n\', toc); 63 64 % Plot predicted prices and actual prices from training set. 65 actual_prices = train.y; 66 predicted_prices = theta\'*train.X; 67 68 % Print out root-mean-squared (RMS) training error.平方根误差 69 train_rms=sqrt(mean((predicted_prices - actual_prices).^2)); 70 fprintf(\'RMS training error: %f\\n\', train_rms); 71 72 % Print out test RMS error 73 actual_prices = test.y; 74 predicted_prices = theta\'*test.X; 75 test_rms=sqrt(mean((predicted_prices - actual_prices).^2)); 76 fprintf(\'RMS testing error: %f\\n\', test_rms); 77 78 79 % Plot predictions on test data. 80 plot_prices=true; 81 if (plot_prices) 82 [actual_prices,I] = sort(actual_prices); %从小到大排序价格 83 predicted_prices=predicted_prices(I); 84 plot(actual_prices, \'rx\'); 85 hold on; 86 plot(predicted_prices,\'bx\'); 87 legend(\'Actual Price\', \'Predicted Price\'); 88 xlabel(\'House #\'); 89 ylabel(\'House price ($1000s)\'); 90 end</span>
3.为了保证运算的准确性我们需要对cost function函数进行运行,得到最小成本函数(本次练习时我添加的代码)
1 % Step 1 : 计算成本函数 2 for i = 1:m 3 f = f + (theta\' * X(:,i) - y(i))^2; 4 end 5 f = 1/2*f; 6 7 % Step 2:计算gradient并储存在g中 8 9 for j = 1:n 10 for i = 1:m 11 g(j) = g(j) + X(j,i)*(theta\' * X(:,i) - y(i)); 12 end
以上是关于机器学习--线性回归的实践的主要内容,如果未能解决你的问题,请参考以下文章
机器学习入门实践——线性回归&非线性回归&mnist手写体识别