Matplotlib:pandas MultiIndex DataFrame 的自定义代码
Posted
技术标签:
【中文标题】Matplotlib:pandas MultiIndex DataFrame 的自定义代码【英文标题】:Matplotlib: custom ticker for pandas MultiIndex DataFrame 【发布时间】:2019-01-25 07:54:33 【问题描述】:我有一个大的pandas MultiIndex DataFrame,我想绘制它。一个最小的示例如下所示:
import pandas as pd
years = range(2015, 2018)
fields = range(4)
days = range(4)
bands = ['R', 'G', 'B']
index = pd.MultiIndex.from_product(
[years, fields], names=['year', 'field'])
columns = pd.MultiIndex.from_product(
[days, bands], names=['day', 'band'])
df = pd.DataFrame(0, index=index, columns=columns)
df.loc[(2015,), (0,)] = 1
df.loc[(2016,), (1,)] = 1
df.loc[(2017,), (2,)] = 1
如果我使用plt.spy
绘制此图,我会得到:
但是,刻度位置和标签不太理想。我希望刻度完全忽略 MultiIndex 的第二级。使用IndexLocator
和IndexFormatter
,我可以做到以下几点:
from matplotlib.ticker import IndexFormatter, IndexLocator
import matplotlib.pyplot as plt
ax = plt.gca()
plt.spy(df)
xbase = len(bands)
xoffset = xbase / 2
xlabels = df.columns.get_level_values('day')
ax.xaxis.set_major_locator(IndexLocator(base=xbase, offset=xoffset))
ax.xaxis.set_major_formatter(IndexFormatter(xlabels))
plt.xlabel('Day')
ax.xaxis.tick_bottom()
ybase = len(fields)
yoffset = ybase / 2
ylabels = df.index.get_level_values('year')
ax.yaxis.set_major_locator(IndexLocator(base=ybase, offset=yoffset))
ax.yaxis.set_major_formatter(IndexFormatter(ylabels))
plt.ylabel('Year')
plt.show()
这正是我想要的:
但这就是问题所在。我的实际 DataFrame 有 15 年、4,000 个字段、365 天和 7 个波段。如果我真的每天都贴标签,标签将难以辨认。我可以每 50 天放置一个刻度,但我希望刻度是动态的,这样当我放大时,刻度变得更加细粒度。基本上我正在寻找的是一个自定义的MultiIndexLocator
,它结合了IndexLocator
的位置和MaxNLocator
的活力。
奖励:我的数据非常好,因为每年总是有相同数量的字段,每天都有相同数量的波段。但如果不是这样呢?我很乐意为matplotlib
贡献一个通用的MultiIndexLocator
和MultiIndexFormatter
,它适用于任何MultiIndex DataFrame。
【问题讨论】:
【参考方案1】:Matplotlib 不知道数据帧或 MultiIndex。它只是绘制您提供的数据。 IE。你得到的结果就像你在绘制 numpy 数据数组一样,spy(df.values)
。
所以我建议首先正确设置图像的范围,以便您可以使用数字代码。那么MaxNLocator
应该可以正常工作,除非您没有放大太多。
import numpy as np
import pandas as pd
from matplotlib.ticker import MaxNLocator
import matplotlib.pyplot as plt
plt.rcParams['axes.formatter.useoffset'] = False
years = range(2000, 2018)
fields = range(9) #17
days = range(120) #365
bands = ['R', 'G', 'B', 'A']
index = pd.MultiIndex.from_product(
[years, fields], names=['year', 'field'])
columns = pd.MultiIndex.from_product(
[days, bands], names=['day', 'band'])
data = np.random.rand(len(years)*len(fields),len(days)*len(bands))
x,y = np.meshgrid(np.arange(data.shape[1]),np.arange(data.shape[0]))
data += 2*((y//len(fields)+x//len(bands)) % 2)
df = pd.DataFrame(data, index=index, columns=columns)
############
# Plotting
############
xbase = len(bands)
xlabels = df.columns.get_level_values('day')
ybase = len(fields)
ylabels = df.index.get_level_values('year')
extent = [xlabels.min()-np.diff(np.unique(xlabels))[0]/2.,
xlabels.max()+np.diff(np.unique(xlabels))[0]/2.,
ylabels.min()-np.diff(np.unique(ylabels))[0]/2.,
ylabels.max()+np.diff(np.unique(ylabels))[0]/2.,]
fig, ax = plt.subplots()
ax.imshow(df.values, extent=extent, aspect="auto")
ax.set_ylabel('Year')
ax.set_xlabel('Day')
ax.xaxis.set_major_locator(MaxNLocator(integer=True,min_n_ticks=1))
ax.yaxis.set_major_locator(MaxNLocator(integer=True,min_n_ticks=1))
plt.show()
【讨论】:
这看起来正是我想要的!这种方法是否只适用于plt.imshow
?我需要手动二值化我的数据吗?我实际上是在尝试可视化 NaN 与非 NaN 像素。
啊,我看到plt.spy
实际上只是plt.imshow
的包装。这应该有效。让我看看我是否可以将我的实际代码转换为使用它。
我成功了!很遗憾integers=True
在我放大得太远时不适用,但我找不到解决这个问题的方法。感谢您的帮助!
它确实需要imshow
,但是(正如您现在可能已经发现的那样)您可以通过使用适当的颜色图来调整imshow 使其看起来与spy
完全相同。如果您需要更多地控制要显示的刻度,您可以注册限制更改事件的回调并设置自定义定位器。但我不确定如果不是整数,它应该打勾。以上是关于Matplotlib:pandas MultiIndex DataFrame 的自定义代码的主要内容,如果未能解决你的问题,请参考以下文章
如何获取“matplotlib”、“numpy”、“scipy”、“pandas”等的存根文件?