statsmodel OLS 和 scikit-learn 线性回归之间的区别
Posted
技术标签:
【中文标题】statsmodel OLS 和 scikit-learn 线性回归之间的区别【英文标题】:Difference between statsmodel OLS and scikit-learn linear regression 【发布时间】:2022-01-01 23:43:44 【问题描述】:我尝试用 iris 数据集练习线性回归模型。
from sklearn import datasets
import seaborn as sns
import pandas as pd
import statsmodels.api as sm
import statsmodels.formula.api as smf
from sklearn.linear_model import LinearRegression
# load iris data
train = sns.load_dataset('iris')
train
# one-hot-encoding
species_encoded = pd.get_dummies(train["species"], prefix = "speceis")
species_encoded
train = pd.concat([train, species_encoded], axis = 1)
train
# Split by feature and target
feature = ["sepal_length", "petal_length", "speceis_setosa", "speceis_versicolor", "speceis_virginica"]
target = ["petal_width"]
X_train = train[feature]
y_train = train[target]
案例 1:统计模型
# model
X_train_constant = sm.add_constant(X_train)
model = sm.OLS(y_train, X_train_constant).fit()
print("const : :.6f".format(model.params[0]))
print(model.params[1:])
result :
const : 0.253251
sepal_length -0.001693
petal_length 0.231921
speceis_setosa -0.337843
speceis_versicolor 0.094816
speceis_virginica 0.496278
案例 2:scikit-learn
# model
model = LinearRegression()
model.fit(X_train, y_train)
print("const : :.6f".format(model.intercept_[0]))
print(pd.Series(model.coef_[0], model.feature_names_in_))
result :
const : 0.337668
sepal_length -0.001693
petal_length 0.231921
speceis_setosa -0.422260
speceis_versicolor 0.010399
speceis_virginica 0.411861
为什么statsmodels和sklearn的结果不一样?
另外,除了全部或部分 one-hot-encoded 特征之外,两个模型的结果是相同的。
【问题讨论】:
【参考方案1】:您将一整套单热编码假人作为回归量包含在内,这导致线性组合等于常数,因此您具有完美的多重共线性:您的协方差矩阵是奇异的,您不能取其逆矩阵。
在幕后statsmodels
和sklearn
都依赖于 Moore-Penrose 伪逆并且可以很好地反转奇异矩阵,问题是在奇异协方差矩阵情况下获得的系数在任何物理意义上都没有任何意义.包之间的实现略有不同(sklearn
依赖于scipy.stats.lstsq
,statsmodels
有一些自定义过程statsmodels.tools.pinv_extended
,基本上是numpy.linalg.svd
,变化很小),所以在一天结束时它们都显示«废话»(因为无法获得有意义的系数),这只是显示什么样的«废话»的设计选择。
如果你取 one-hot 编码假人的系数之和,你可以看到 statsmodels
等于常数,sklearn
等于 0,而常数不同于 @987654330 @ 持续的。对完美多重共线性不“负责任”的变量的系数不受影响。
【讨论】:
以上是关于statsmodel OLS 和 scikit-learn 线性回归之间的区别的主要内容,如果未能解决你的问题,请参考以下文章
我们如何计算 statsmodels OLS 中的截距和斜率?
为啥当我使用 statsmodels 进行 OLS 和使用 scikit 进行 PooledOLS 时得到相同的结果?
AttributeError:模块“statsmodels.formula.api”没有属性“OLS”
使用 statsmodels.formula.api 中的 ols - 如何删除常数项?