使用 Matplotlib 在 3d 中绘制线性模型

Posted

技术标签:

【中文标题】使用 Matplotlib 在 3d 中绘制线性模型【英文标题】:Plot linear model in 3d with Matplotlib 【发布时间】:2014-12-13 10:29:48 【问题描述】:

我正在尝试创建适合数据集的线性模型的 3d 图。我能够在 R 中相对轻松地做到这一点,但我真的很难在 Python 中做到这一点。这是我在 R 中所做的:

这是我在 Python 中所做的:

from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import statsmodels.formula.api as sm

csv = pd.read_csv('http://www-bcf.usc.edu/~gareth/ISL/Advertising.csv', index_col=0)
model = sm.ols(formula='Sales ~ TV + Radio', data = csv)
fit = model.fit()

fit.summary()

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

ax.scatter(csv['TV'], csv['Radio'], csv['Sales'], c='r', marker='o')

xx, yy = np.meshgrid(csv['TV'], csv['Radio'])

# Not what I expected :(
# ax.plot_surface(xx, yy, fit.fittedvalues)

ax.set_xlabel('TV')
ax.set_ylabel('Radio')
ax.set_zlabel('Sales')

plt.show()

我做错了什么,我应该怎么做?

谢谢。

【问题讨论】:

您介意提供您的代码吗? 【参考方案1】:

知道了!

我在 cmets 中谈到 mdurant 回答的问题是,表面没有像 Combining scatter plot with surface plot 那样绘制成漂亮的方形图案。

我意识到问题出在我的meshgrid 上,因此我更正了两个范围(xy)并为np.arange 使用了比例步长。

这让我可以使用 mdurant 的答案提供的代码,而且效果很好!

结果如下:

这是代码:

from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import statsmodels.formula.api as sm
from matplotlib import cm

csv = pd.read_csv('http://www-bcf.usc.edu/~gareth/ISL/Advertising.csv', index_col=0)
model = sm.ols(formula='Sales ~ TV + Radio', data = csv)
fit = model.fit()

fit.summary()

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

x_surf = np.arange(0, 350, 20)                # generate a mesh
y_surf = np.arange(0, 60, 4)
x_surf, y_surf = np.meshgrid(x_surf, y_surf)

exog = pd.core.frame.DataFrame('TV': x_surf.ravel(), 'Radio': y_surf.ravel())
out = fit.predict(exog = exog)
ax.plot_surface(x_surf, y_surf,
                out.reshape(x_surf.shape),
                rstride=1,
                cstride=1,
                color='None',
                alpha = 0.4)

ax.scatter(csv['TV'], csv['Radio'], csv['Sales'],
           c='blue',
           marker='o',
           alpha=1)

ax.set_xlabel('TV')
ax.set_ylabel('Radio')
ax.set_zlabel('Sales')

plt.show()

【讨论】:

【参考方案2】:

您假设 plot_surface 想要一个坐标网格网格来使用是正确的,但是 predict 想要一个像您安装的数据结构(“exog”)。

exog = pd.core.frame.DataFrame('TV':xx.ravel(),'Radio':yy.ravel())
out = fit.predict(exog=exog)
ax.plot_surface(xx, yy, out.reshape(xx.shape), color='None')

【讨论】:

这种工作,但生成的平面绘制得很糟糕。我很困惑,似乎没有更好、更简单的方法来做到这一点。 严重怎么办?您可以为 plot_surface 提供许多参数。但是,如果您想要花哨的 3D,您可能需要研究 mayavi。 这个imgur.com/A5w55U6 是我得到的(我删除了color='None')。我正在寻找这样的东西:***.com/questions/15229896/…。顺便说一句,感谢您的帮助。 知道了!请参阅下面的答案。我将您的标记为正确。再次感谢。

以上是关于使用 Matplotlib 在 3d 中绘制线性模型的主要内容,如果未能解决你的问题,请参考以下文章

Python 使用 matplotlib绘制3D图形

使用 Python matplotlib 绘制 3d 矢量

如何使用 matplotlib 在 python 中绘制 3D 密度图

python Matplotlib绘制线性图

在 Matplotlib 中绘制一个 3d 立方体、一个球体和一个向量

Matplotlib - 同时绘制 3D 平面和点