简单线性回归预测实现
Posted Erick - LONG
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了简单线性回归预测实现相关的知识,希望对你有一定的参考价值。
import numpy as np from .metrics import r2_score class SimpleLinearRegression: def __init__(self): """初始化Simple Linear Regression模型""" self.a_ = None self.b_ = None def fit(self, x_train, y_train): """根据训练数据集x_train, y_train训练Simple Linear Regression模型""" assert x_train.ndim == 1, "Simple Linear Regressor can only solve single feature training data." assert len(x_train) == len(y_train), "the size of x_train must be equal to the size of y_train" x_mean = np.mean(x_train) y_mean = np.mean(y_train) self.a_ = (x_train - x_mean).dot(y_train - y_mean) / (x_train - x_mean).dot(x_train - x_mean)#向量化 点乘运算 self.b_ = y_mean - self.a_ * x_mean return self def predict(self, x_predict): """给定待预测数据集x_predict,返回表示x_predict的结果向量""" assert x_predict.ndim == 1, "Simple Linear Regressor can only solve single feature training data." assert self.a_ is not None and self.b_ is not None, "must fit before predict!" return np.array([self._predict(x) for x in x_predict]) def _predict(self, x_single): """给定单个待预测数据x,返回x的预测结果值""" return self.a_ * x_single + self.b_ def score(self, x_test, y_test): """根据测试数据集 x_test 和 y_test 确定当前模型的准确度""" y_predict = self.predict(x_test) return r2_score(y_test, y_predict) def __repr__(self): return "SimpleLinearRegression()"
衡量线性回归模型误差的三种方式
def mean_squared_error(y_true, y_predict): """计算y_true和y_predict之间的MSE""" assert len(y_true) == len(y_predict), "the size of y_true must be equal to the size of y_predict" return np.sum((y_true - y_predict)**2) / len(y_true) def root_mean_squared_error(y_true, y_predict): """计算y_true和y_predict之间的RMSE""" return sqrt(mean_squared_error(y_true, y_predict)) def mean_absolute_error(y_true, y_predict): """计算y_true和y_predict之间的RMSE""" assert len(y_true) == len(y_predict), "the size of y_true must be equal to the size of y_predict" return np.sum(np.absolute(y_true - y_predict)) / len(y_true)
计算最终模型预测准确度R方
def r2_score(y_true, y_predict): """计算y_true和y_predict之间的R Square""" return 1 - mean_squared_error(y_true, y_predict)/np.var(y_true)
以上是关于简单线性回归预测实现的主要内容,如果未能解决你的问题,请参考以下文章