多元线性回归模型的等高线图
Posted
技术标签:
【中文标题】多元线性回归模型的等高线图【英文标题】:Contour plot for multi linear regression model 【发布时间】:2021-12-20 00:03:08 【问题描述】:我必须使用以下变量获取等高线图以获得一系列最佳值:
X axis = SiO2/Al2O3
Y axis = Precursor/Aggregate
Z axis = Compressive Strength
我的代码如下
import numpy as np
import matplotlib as mlt
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score
dataset = pd.read_csv('Data.csv')
X = dataset.iloc[:, :-1].values
y = dataset.iloc[:, -1].values
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.2, random_state = 0)
regressor = LinearRegression()
regressor.fit(X_train, y_train)
y_predict = regressor.predict(X_test)
feature_x = X_test[:, 1]
feature_y = X_test[:, 3]
[X, Y] = np.meshgrid(feature_x, feature_y)
Z = y_predict
ax.contourf(X, Y, Z)
ax.set_title('Filled Contour Plot')
ax.set_xlabel('SiO2/Al2O3')
ax.set_ylabel('Precursor/Aggregate')
plt.show()
但它给出了这个错误
TypeError: Input z must be 2D, not 1D
我想我在 Z 轴输入中犯了一个错误。
数据可用at this link。
预期输出:
【问题讨论】:
也许你需要Z = y_predict.reshape(X.shape)
?
【参考方案1】:
您的代码将不起作用,您需要为您的预测值创建一个网格,首先我们读取您的数据并进行拟合:
dataset = pd.read_csv('Data.csv')
X = dataset.iloc[:, :-1].values
y = dataset.iloc[:, -1].values
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.2, random_state = 0)
regressor = LinearRegression()
regressor.fit(X_train, y_train)
然后你需要为你感兴趣的特征创建一个网格:
feature_x = np.linspace(X_test[:, 1].min(),X_test[:, 1].max(),100)
feature_y = np.linspace(X_test[:, 3].min(),X_test[:, 3].max(),100)
网格:
dim1, dim2 = np.meshgrid(feature_x, feature_y)
现在,您的模型还有 6 个其他预测变量需要您提供以进行拟合。一种方法是将这些其他变量保持在它们的平均值,然后我们在网格中插入:
mesh_df = np.array([X_test.mean(axis=0) for i in range(dim1.size)])
mesh_df[:,1] = dim1.ravel()
mesh_df[:,2] = dim2.ravel()
现在预测、重塑和绘制:
Z = regressor.predict(mesh_df).reshape(dim1.shape)
fig, ax = plt.subplots()
ax.contourf(dim1, dim2, Z)
ax.set_title('Filled Contour Plot')
ax.set_xlabel('SiO2/Al2O3')
ax.set_ylabel('Precursor/Aggregate')
plt.show()
看起来像这样,因为您使用的是线性回归,值将随变量线性增加或减少:
【讨论】:
以上是关于多元线性回归模型的等高线图的主要内容,如果未能解决你的问题,请参考以下文章