如何添加一条最适合散点图的线
Posted
技术标签:
【中文标题】如何添加一条最适合散点图的线【英文标题】:How to add a line of best fit to scatter plot 【发布时间】:2016-09-11 01:50:39 【问题描述】:我目前正在使用 Pandas 和 matplotlib 来执行一些数据可视化,我想在我的散点图中添加一条最适合的线。
这是我的代码:
import matplotlib
import matplotlib.pyplot as plt
import pandas as panda
import numpy as np
def PCA_scatter(filename):
matplotlib.style.use('ggplot')
data = panda.read_csv(filename)
data_reduced = data[['2005', '2015']]
data_reduced.plot(kind='scatter', x='2005', y='2015')
plt.show()
PCA_scatter('file.csv')
我该怎么做?
【问题讨论】:
这能回答你的问题吗? Code for best fit straight line of a scatter plot in python 【参考方案1】:您可以使用np.polyfit()
和np.poly1d()
。使用相同的x
值估计一次多项式,并添加到由.scatter()
绘图创建的ax
对象。举个例子:
import numpy as np
2005 2015
0 18882 21979
1 1161 1044
2 482 558
3 2105 2471
4 427 1467
5 2688 2964
6 1806 1865
7 711 738
8 928 1096
9 1084 1309
10 854 901
11 827 1210
12 5034 6253
估计一次多项式:
z = np.polyfit(x=df.loc[:, 2005], y=df.loc[:, 2015], deg=1)
p = np.poly1d(z)
df['trendline'] = p(df.loc[:, 2005])
2005 2015 trendline
0 18882 21979 21989.829486
1 1161 1044 1418.214712
2 482 558 629.990208
3 2105 2471 2514.067336
4 427 1467 566.142863
5 2688 2964 3190.849200
6 1806 1865 2166.969948
7 711 738 895.827339
8 928 1096 1147.734139
9 1084 1309 1328.828428
10 854 901 1061.830437
11 827 1210 1030.487195
12 5034 6253 5914.228708
和情节:
ax = df.plot.scatter(x=2005, y=2015)
df.set_index(2005, inplace=True)
df.trendline.sort_index(ascending=False).plot(ax=ax)
plt.gca().invert_xaxis()
获得:
还提供了线方程:
'y=0:.2f x + 1:.2f'.format(z[0],z[1])
y=1.16 x + 70.46
【讨论】:
trendline.plot(ax=ax)
行给了我一个无效的语法错误
z = np.polyfit(x=data_reduced[['2005']], y=data_reduced[['2015']], 1)
行给了我一个“位置参数跟随关键字参数”错误
对不起,degree
需要在=1
之前添加deg
,见更新。
TypeError: 对于行 z = np.polyfit(x=data_reduced[['2005']], y=data_reduced[['2015']], deg=1)
,x 的预期一维向量。这是我的数据或代码的问题吗?
需要使用.loc[]
,所以单列变成pd.Series
。使用[[]]
选择会保留一列作为DataFrame
,因此会出现维度警告。更新,同样适用于下一行。不好意思,时间不早了……【参考方案2】:
另一个选项(使用np.linalg.lstsq
):
# generate some fake data
N = 50
x = np.random.randn(N, 1)
y = x*2.2 + np.random.randn(N, 1)*0.4 - 1.8
plt.axhline(0, color='r', zorder=-1)
plt.axvline(0, color='r', zorder=-1)
plt.scatter(x, y)
# fit least-squares with an intercept
w = np.linalg.lstsq(np.hstack((x, np.ones((N,1)))), y)[0]
xx = np.linspace(*plt.gca().get_xlim()).T
# plot best-fit line
plt.plot(xx, w[0]*xx + w[1], '-k')
【讨论】:
【参考方案3】:您可以使用Seaborn 一口气完成所有工作和情节。
import pandas as pd
import seaborn as sns
data_reduced= pd.read_csv('fake.txt',sep='\s+')
sns.regplot(data_reduced['2005'],data_reduced['2015'])
【讨论】:
但是我想用matplotlib! :( 这个解决方案多么简单,真是太棒了!非常感谢! 如果您想在循环和创建多个图表时一次查看一个图表,您仍然需要 matplotlib 的 plt.show() 【参考方案4】:这涵盖了plotly
方法
#load the libraries
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
# create the data
N = 50
x = pd.Series(np.random.randn(N))
y = x*2.2 - 1.8
# plot the data as a scatter plot
fig = px.scatter(x=x, y=y)
# fit a linear model
m, c = fit_line(x = x,
y = y)
# add the linear fit on top
fig.add_trace(
go.Scatter(
x=x,
y=m*x + c,
mode="lines",
line=go.scatter.Line(color="red"),
showlegend=False)
)
# optionally you can show the slop and the intercept
mid_point = x.mean()
fig.update_layout(
showlegend=False,
annotations=[
go.layout.Annotation(
x=mid_point,
y=m*mid_point + c,
xref="x",
yref="y",
text=str(round(m, 2))+'x+'+str(round(c, 2)) ,
)
]
)
fig.show()
fit_line
在哪里
def fit_line(x, y):
# given one dimensional x and y vectors - return x and y for fitting a line on top of the regression
# inspired by the numpy manual - https://docs.scipy.org/doc/numpy/reference/generated/numpy.linalg.lstsq.html
x = x.to_numpy() # convert into numpy arrays
y = y.to_numpy() # convert into numpy arrays
A = np.vstack([x, np.ones(len(x))]).T # sent the design matrix using the intercepts
m, c = np.linalg.lstsq(A, y, rcond=None)[0]
return m, c
【讨论】:
【参考方案5】:上面的最佳答案是使用 seaborn。 补充一点,如果你用循环创建许多图,你仍然可以使用 matplotlib
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
data_reduced= pd.read_csv('fake.txt',sep='\s+')
for x in data_reduced.columns:
sns.regplot(data_reduced[x],data_reduced['2015'])
plt.show()
plt.show() 将暂停执行,以便您一次查看一个图
【讨论】:
【参考方案6】:只是添加到(更新罗伯特卡尔霍恩的答案)。如果您不指定 x,y,您现在将在新版本的 pandas 上收到未来警告。
FutureWarning: Pass the following variables as keyword args: x, y. From version 0.12, the only valid positional argument will be `data`, and passing other arguments without an explicit keyword will result in an error or misinterpretation.
所以,如下。
import pandas as pd
import seaborn as sns
data_reduced= pd.read_csv('fake.txt',sep='\s+')
sns.regplot(x=data_reduced['2005'],y=data_reduced['2015'])
【讨论】:
以上是关于如何添加一条最适合散点图的线的主要内容,如果未能解决你的问题,请参考以下文章