python实现ols回归

Posted ftcyllb

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了python实现ols回归相关的知识,希望对你有一定的参考价值。

文章目录

一、回归方法结构

  1. 回归可选参数:稳健标准误、向后逐步回归
  2. 异常检测:异方差(BP检验、怀特检验),多重共线性(VIF检验)

二、代码结构

类型名称功能
funcmain启动函数
classMyRegResult封装需要存储的回归结果
classExogEmpty自变量空时抛出异常
funcmy_olsols回归主干函数
funcback_ols向后逐步回归
funcstan_calcu计算标准化回归系数
funcarch_test异方差检测
funcvif_test检测多重共线性
funcsummary打印并存储回归结果
functxt_excel将ols默认生成的描述转为excel

三、代码详解

3.1 导包

import openpyxl
import pandas as pd
import statsmodels.formula.api as smf
from statsmodels.stats.diagnostic import het_breuschpagan, het_white
from statsmodels.stats.outliers_influence import variance_inflation_factor
from statsmodels.regression.linear_model import RegressionResultsWrapper

3.2 main:启动函数

if __name__ == '__main__':
    excel = pd.read_excel('regression_test_2.xlsx', 'Sheet1')
    robust = True                   # 稳健标准误
    back_reg = True                 # 是否使用向后逐步回归
	
	# 此处设置为0.6是方便测试(因为数据集都不显著,自变量会被清空,可以自己设为0.05试试)
    back_reg_p = 0.6               	# 向后逐步回归的p值,一般是0.05
	
	file_name = None                # 指定文件名,可为None使用默认名字
    y = 1                           # int,因变量列id
    x = list(range(2, 7))           # int的列表,自变量列id
	
    my_reg_res = my_ols(excel, y, x, is_robust=robust, back=back_reg, back_p=back_reg_p)  # 回归
    summary(my_reg_res, is_robust=robust, name=file_name)       # 打印并存储回归结果

3.3 MyRegResult:封装需要存储的回归结果

class MyRegResult(object):
    def __init__(self,
				# smf.ols自动生成的结果
                 results: RegressionResultsWrapper = None,
				 # 标准化回归系数,smf.ols不自动生成,需要手算
                 stan_param: pd.DataFrame = None,
				 # 异方差检测
                 arch: pd.DataFrame = None,
				 # 多重共线性检测
                 vif: pd.Series = None
                 ):
        self.results = results
        self.stan_param = stan_param
        self.arch = arch
        self.vif = vif
        pass

    def __str__(self):
		# 打印时自动调用
        print(self.results.summary(), self.stan_param, self.arch, self.vif, sep='\\n\\n')
        return ''

3.4 ExogEmpty:自变量空时抛出异常

在向后逐步回归时,会删去自变量,当自变量被删空时,需要停止程序

class ExogEmpty(Exception):
    def __str__(self):
        return '自变量空'

3.5 my_ols:ols回归主干

# 传入参数的含义在main中有
# return:MyRegResult对象,回归结果
def my_ols(data: pd.DataFrame, y_id=None, x_id=None, is_robust=False, back=False, back_p=0.05):
    data = data.copy()  # 完全拷贝,避免误修改外部变量
    reg_res = MyRegResult()		# 封装回归结果,再一个一个赋值
	
	# 根据y_id和x_id拼接formula
    var_name = data.columns
    y_str = str(var_name[y_id])
    x_str = '+'.join([str(i) for i in var_name[x_id]])
    formula = y_str + '~' + x_str
	
    model = smf.ols(formula, data)
	# 包含虚拟变量的自变量(ols获取的model会自动虚拟化)
	# 后续需要用到虚拟化后的变量
    exog = pd.DataFrame(model.exog, columns=model.exog_names)	# 自变量
	
	# 获取回归结果
    if back:  # 向后逐步回归
        endog = pd.Series(data[y_str], name=y_str)	# 因变量
		# 除了回归结果,还会返回一个修改后的自变量,用于后续计算
        res, exog = back_ols(endog, exog, back_p, is_robust)
    else:  # 直接回归
        res = model.fit(cov_type='HC1') if is_robust else model.fit()
    reg_res.results = res
	
    # 计算标准化回归系数
    y_std = data[y_str].std()	# 因变量方差
    reg_res.stan_param = stan_calcu(res.params, y_std, exog)
	
    # 异方差检验
    reg_res.arch = arch_test(res.resid, exog)
	
    # 检测多重共线性
    reg_res.vif = vif_test(exog)
    return reg_res

3.6 stan_calcu:计算标准化回归系数

公式:标准化回归系数 = 未标准化回归系数 * 该自变量的标准差 / 因变量的标准差

**与非标准化回归系数的区别:**标准化回归系数是在对自变量和因变量同时进行标准化处理后所得到的回归系数,数据经过标准化处理后消除了量纲、数量级等差异的影响,使得不同变量之间具有可比性

# param:非标准化回归系数
# y_std:因变量方差
# exog:自变量
# return:标准化回归系数
def stan_calcu(param: pd.Series, y_std, exog: pd.DataFrame):
    stan = param.rename_axis('stan_params', axis='index')	# 导入非标准化参数
    stan = stan.drop('Intercept')	# 去除截距项
    for i, v in stan.items():
        stan[i] = v * exog[i].std() / y_std		# 计算标准化参数
    return stan

3.7 arch_test:异方差检验

异方差检测H0:不存在异方差
当p<0.05时,拒绝原假设
bp不可检测平方项和交互项,bp检验是怀特检验的一种特例
bp可检测方程:(无平方项、交互项)

# resid:每个样本的残差
# exog:自变量
def arch_test(resid: pd.Series, exog: pd.DataFrame):
    arch = pd.DataFrame(index=['stati', 'p', '注释'], columns=['bp', 'white']).rename_axis('ARCH', axis='columns')
    arch.loc['stati':'p', 'bp'] = het_breuschpagan(resid, exog)[:2]		# bp检测
    arch.loc['stati':'p', 'white'] = het_white(resid, exog)[:2]		# 怀特检测
    arch.loc['注释', 'bp'] = '异方差检测H0:不存在异方差,bp不可检测平方项和交互项'
    return arch

3.8 vif_test:检测多重共线性

# exog:自变量
def vif_test(exog: pd.DataFrame):
	# 计算vif
    vif = pd.Series([variance_inflation_factor(exog, i) for i in range(exog.shape[1])],
                    index=exog.columns).rename_axis('VIF', axis='index')
    vif = vif.drop('Intercept')		# 去除截距项
    vif = vif.sort_values(ascending=False)		# 降序查看高共线性变量
    return vif

3.9 back_ols:向后逐步回归

思想:先将所有变量均放入模型,不断迭代将最没有解释力的那个自变量剔除,直到没有自变量符合剔除的条件

# 向后逐步回归:返回回归结果与修改后的自变量(包含截距项)
def back_ols(endog: pd.Series, exog: pd.DataFrame, back_p, is_robust):
    # 为了避免完全多重共线性的影响,引入虚拟变量的个数一般是分类数减1(前面ols生成的model的exog已经自动减1了)
    endog, exog = endog.copy(), exog.copy()
    # smf.ols会自动加上截距项,此处删去避免重复
    intercept = exog.pop('Intercept')
    # 循环检测p值,直到全小于back_p
    while True:
        # 共线性检验、向后逐步回归都会删除自变量,检测自变量是否为空
        if exog.empty:
            raise ExogEmpty()
        data = pd.concat([exog, endog], axis=1)
        formula = str(endog.name) + '~' + '+'.join(exog.columns.values)
        model = smf.ols(formula, data)
        res = model.fit(cov_type='HC1') if is_robust else model.fit()
        p = res.pvalues.drop('Intercept')
        p_max_id = p.idxmax()
        if p[p_max_id] > back_p:
            exog = exog.drop(p_max_id, axis=1)
        else:	# 全部小于预设p值,退出循环
            break
    # 在逐步回归后,仍需进行异方差检验和多重共线性检验,需要加回截距项
    exog.insert(0, 'Intercept', intercept)
    return res, exog

3.10 summary:打印并存储回归结果

# reg_res:回归结果
def summary(reg_res: MyRegResult, is_robust=False, name=None):
    print(reg_res)
    if name:    # 指定文件名
        f = get_file_path(name)
    else:   # 默认文件名
        f = get_file_path('reg_robust.xlsx' if is_robust else 'reg.xlsx', 'out')
    # 写入summary
    txt_to_excel(reg_res.results.summary().as_csv(), f, sheet_name='summary', split_flag=',')
    with pd.ExcelWriter(f, mode="a", engine="openpyxl") as writer:
        # 写入标准化回归系数
        reg_res.stan_param.to_excel(writer, 'stan_param')
        # 写入arch
        reg_res.arch.to_excel(writer, 'ARCH')
        # 写入vif
        reg_res.vif.to_excel(writer, 'VIF')

3.11 txt_excel:将ols默认生成的描述转为excel

def txt_to_excel(txt, excel, sheet_name='Sheet1', split_flag=' '):
    wb = openpyxl.Workbook()
    wb.remove(wb.active)
    ws = wb.create_sheet(sheet_name)
    txt_line = txt.split('\\n')
    for i, line in enumerate(txt_line):
        txt_cell = line.split(split_flag)
        for j, cell in enumerate(txt_cell):
            ws.cell(row=i+1, column=j+1).value = cell.strip()
    wb.save(excel)

四、附件

链接:https://pan.baidu.com/s/19Is8uT_47ZzhJKJm5jEuyA?pwd=0szv
提取码:0szv

以上是关于python实现ols回归的主要内容,如果未能解决你的问题,请参考以下文章

5.2 多元线性回归完成广告投放销售额预测——python实战

数学建模学习:岭回归和lasso回归

超详细多元线性回归模型statsmodels_ols

超详细多元线性回归模型statsmodels_ols

线性回归-OLS法

查找 p 值和 z 统计量以及 OLS 线性回归