如何在 python 中使用 .predict() 方法进行线性回归?
Posted
技术标签:
【中文标题】如何在 python 中使用 .predict() 方法进行线性回归?【英文标题】:How to use .predict() method in python for linear regression? 【发布时间】:2021-07-23 00:58:36 【问题描述】:我从我的数据框中估计了多元回归模型。我有三个独立变量:月份(1 到 36)、价格和广告日。
我想做出预测,改变条件:
-未来 10 个月(37 到 47)的预测值,价格 = 85,广告日 = 4
我估计了我的模型并尝试了:
Time1= np.arange(37,48)
Price1=85
Ads1=4
Lm.predict([Time1,Price1,Ads1])
但它不起作用
谢谢
【问题讨论】:
【参考方案1】:你需要一个二维数组
Lm.predict([[Time1,Price1,Ads1]])
【讨论】:
【参考方案2】:假设您的模型是在没有任何嵌套数组的二维数组上训练的,问题是:
-
您要预测的输入不是二维的
变量
Time1
本身就是一个数组,因此,您创建了一个嵌套数组:[Time1,Price1,Ads1]
您当前的预测调用如下:
Time1 = np.arange(37,48)
Price1=85
Ads1=4
print([Time1,Price1,Ads1])
看起来像:
[array([37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47]), 85, 4]
您可以像这样将其转换为所需的格式:
import numpy as np
print(np.concatenate([Time1, [Price1, Ads1]]).reshape(1,-1))
看起来像:
array([[37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 85, 4]])
【讨论】:
【参考方案3】:首先使用过去观察的训练数据训练模型。在您的情况下,训练数据构成每个观察的 3 个三个自变量和 1 个因变量。
一旦训练出体面的模型(使用超参数优化),您就可以使用它来进行预测。
示例代码(内嵌文档)
import numpy as np
from sklearn.linear_model import LinearRegression
# sample dummy data
# independent variables
time = np.arange(1,36)
price = np.random.randint(1,100,35)
ads = np.random.randint(1,10,35)
# dependent variable
y = np.random.randn(35)
# Reshape it into 35X3 where each row is an observation
train_X = np.vstack([time, price, ads]).T
# Fit the model
model = LinearRegression().fit(train_X, y)
# Sample observations for which
# forecast of dependent variable has to be made
time1 = np.arange(37, 47)
price1 = np.array([85]*len(time1))
ads1 = np.array([4]*len(time1))
# Reshape such that each row is an observation
test_X = np.vstack([time1, price1, ads1]).T
# make the predictions
print (model.predict(test_X))'
输出:
array([0.22189608, 0.2269302 , 0.23196433, 0.23699845, 0.24203257,
0.24706669, 0.25210081, 0.25713494, 0.26216906, 0.26720318])
【讨论】:
以上是关于如何在 python 中使用 .predict() 方法进行线性回归?的主要内容,如果未能解决你的问题,请参考以下文章
python - 如何从python中sklearn中的cross_val_predict获取排序的概率和名称
如何将文件或图像作为 Keras 模型中的参数提供给 model.predict?
predict_proba 不适用于我的高斯混合模型(sklearn,python)