使用散点数据集在 MatPlotLib 中生成热图
Posted
技术标签:
【中文标题】使用散点数据集在 MatPlotLib 中生成热图【英文标题】:Generate a heatmap in MatPlotLib using a scatter data set 【发布时间】:2011-01-23 02:13:37 【问题描述】:我有一组 X、Y 数据点(大约 10k),很容易绘制为散点图,但我想将其表示为热图。
我查看了 MatPlotLib 中的示例,它们似乎都已经从热图单元格值开始生成图像。
有没有一种方法可以将一堆不同的 x,y 转换为热图(x,y 频率较高的区域会“更暖”)?
【问题讨论】:
同样相关:Efficient method of calculating density of irregularly spaced points 【参考方案1】:如果不想要六边形,可以使用numpy的histogram2d
函数:
import numpy as np
import numpy.random
import matplotlib.pyplot as plt
# Generate some test data
x = np.random.randn(8873)
y = np.random.randn(8873)
heatmap, xedges, yedges = np.histogram2d(x, y, bins=50)
extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]
plt.clf()
plt.imshow(heatmap.T, extent=extent, origin='lower')
plt.show()
这会生成一个 50x50 的热图。如果你想要,比如说 512x384,你可以把bins=(512, 384)
加入到histogram2d
的调用中。
示例:
【讨论】:
我并不是想成为一个白痴,但是您实际上如何将此输出到 PNG/PDF 文件而不是仅在交互式 IPython 会话中显示?我试图将其作为某种普通的axes
实例,我可以在其中添加标题、轴标签等,然后像对任何其他典型的 matplotlib 图一样执行普通的 savefig()
。跨度>
@gotgenes:plt.savefig('filename.png')
不起作用吗?如果要获取坐标轴实例,请使用 Matplotlib 的面向对象接口:fig = plt.figure()
ax = fig.gca()
ax.imshow(...)
fig.savefig(...)
确实,谢谢!我想我不完全理解imshow()
与scatter()
属于同一类函数。老实说,我不明白为什么 imshow()
将浮点数的二维数组转换为适当颜色的块,而我确实明白 scatter()
应该如何处理这样的数组。
关于使用 imshow 绘制 x/y 值的二维直方图的警告,如下所示:默认情况下,imshow 在左上角绘制原点并转置图像。为了获得与散点图相同的方向,我会做的是plt.imshow(heatmap.T, extent=extent, origin = 'lower')
对于那些想要做对数颜色条的人,请参阅这个问题 ***.com/questions/17201172/… 并简单地做 from matplotlib.colors import LogNorm
plt.imshow(heatmap, norm=LogNorm())
plt.colorbar()
【参考方案2】:
在 Matplotlib 词典中,我认为你想要一个 hexbin 图。
如果您不熟悉这种类型的图,它只是一个 二元直方图,其中 xy 平面由六边形的规则网格细分。
因此,从直方图中,您可以只计算落在每个六边形中的点数,将绘图区域离散化为一组 窗口,将每个点分配给其中一个窗口;最后,将窗口映射到一个颜色数组,你就得到了一个十六进制图。
虽然不像圆形或正方形那样常用,但六边形是分箱容器几何形状的更好选择是直观的:
六边形具有最近邻对称性(例如,方形箱没有, 例如,距离 from 正方形边界上的一个点 到 一个点 在那个正方形里面不是处处相等)和
六边形是给出正平面的最高n-多边形 镶嵌(即,您可以安全地使用六边形瓷砖重新塑造厨房地板,因为完成后瓷砖之间不会有任何空隙空间 - 对于所有其他更高的 n,n 并非如此>= 7,多边形)。
(Matplotlib 使用术语 hexbin 绘图;(AFAIK) 所有 plotting libraries 的 R 也是如此;我仍然没有'不知道这是否是这种类型的地块的普遍接受的术语,尽管我怀疑这可能是因为 hexbin 是 hexagonal binning 的缩写,它描述了准备要显示的数据。)
from matplotlib import pyplot as PLT
from matplotlib import cm as CM
from matplotlib import mlab as ML
import numpy as NP
n = 1e5
x = y = NP.linspace(-5, 5, 100)
X, Y = NP.meshgrid(x, y)
Z1 = ML.bivariate_normal(X, Y, 2, 2, 0, 0)
Z2 = ML.bivariate_normal(X, Y, 4, 1, 1, 1)
ZD = Z2 - Z1
x = X.ravel()
y = Y.ravel()
z = ZD.ravel()
gridsize=30
PLT.subplot(111)
# if 'bins=None', then color of each hexagon corresponds directly to its count
# 'C' is optional--it maps values to x-y coordinates; if 'C' is None (default) then
# the result is a pure 2D histogram
PLT.hexbin(x, y, C=z, gridsize=gridsize, cmap=CM.jet, bins=None)
PLT.axis([x.min(), x.max(), y.min(), y.max()])
cb = PLT.colorbar()
cb.set_label('mean value')
PLT.show()
【讨论】:
“六边形具有最近邻对称性”是什么意思?您说“从正方形边界上的一个点到该正方形内的一个点的距离并非处处相等”但是到什么距离? 对于六边形,中心到连接两侧的顶点的距离也比中心到边中的距离长,只是比值更小(六边形的2 / sqrt(3)≈1.15与 sqrt(2) ≈ 1.41 的平方)。从中心到边界上每个点的距离都相等的唯一形状是圆形。 @Jaan 对于六边形,每个邻居的距离相同。 8-neighborhood 或 4-neighborhood 没有问题。没有对角线邻居,只有一种邻居。 @doug 如何选择gridsize=
参数。我想这样选择它,这样六边形就可以接触而不会重叠。我注意到gridsize=100
会产生更小的六边形,但是如何选择合适的值呢?【参考方案3】:
编辑:为了更好地近似亚历杭德罗的答案,请参见下文。
我知道这是一个老问题,但想在 Alejandro 的回答中添加一些内容:如果您想要一个漂亮的平滑图像而不使用 py-sphviewer,您可以改用 np.histogram2d
并应用高斯滤波器(来自 scipy.ndimage.filters
)到热图:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from scipy.ndimage.filters import gaussian_filter
def myplot(x, y, s, bins=1000):
heatmap, xedges, yedges = np.histogram2d(x, y, bins=bins)
heatmap = gaussian_filter(heatmap, sigma=s)
extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]
return heatmap.T, extent
fig, axs = plt.subplots(2, 2)
# Generate some test data
x = np.random.randn(1000)
y = np.random.randn(1000)
sigmas = [0, 16, 32, 64]
for ax, s in zip(axs.flatten(), sigmas):
if s == 0:
ax.plot(x, y, 'k.', markersize=5)
ax.set_title("Scatter plot")
else:
img, extent = myplot(x, y, s)
ax.imshow(img, extent=extent, origin='lower', cmap=cm.jet)
ax.set_title("Smoothing with $\sigma$ = %d" % s)
plt.show()
生产:
Agape Gal'lo 的散点图和 s=16 相互重叠(点击查看更好的视图):
我注意到我的高斯滤波器方法和 Alejandro 方法的一个不同之处在于,他的方法比我的方法更能显示局部结构。因此,我在像素级别实现了一个简单的最近邻方法。此方法为每个像素计算数据中n
最近点的距离的倒数和。这种方法分辨率很高,计算成本很高,我认为有更快的方法,所以如果您有任何改进,请告诉我。
更新:正如我所怀疑的,使用 Scipy 的scipy.cKDTree
有一种更快的方法。实现见Gabriel's answer。
无论如何,这是我的代码:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
def data_coord2view_coord(p, vlen, pmin, pmax):
dp = pmax - pmin
dv = (p - pmin) / dp * vlen
return dv
def nearest_neighbours(xs, ys, reso, n_neighbours):
im = np.zeros([reso, reso])
extent = [np.min(xs), np.max(xs), np.min(ys), np.max(ys)]
xv = data_coord2view_coord(xs, reso, extent[0], extent[1])
yv = data_coord2view_coord(ys, reso, extent[2], extent[3])
for x in range(reso):
for y in range(reso):
xp = (xv - x)
yp = (yv - y)
d = np.sqrt(xp**2 + yp**2)
im[y][x] = 1 / np.sum(d[np.argpartition(d.ravel(), n_neighbours)[:n_neighbours]])
return im, extent
n = 1000
xs = np.random.randn(n)
ys = np.random.randn(n)
resolution = 250
fig, axes = plt.subplots(2, 2)
for ax, neighbours in zip(axes.flatten(), [0, 16, 32, 64]):
if neighbours == 0:
ax.plot(xs, ys, 'k.', markersize=2)
ax.set_aspect('equal')
ax.set_title("Scatter Plot")
else:
im, extent = nearest_neighbours(xs, ys, resolution, neighbours)
ax.imshow(im, origin='lower', extent=extent, cmap=cm.jet)
ax.set_title("Smoothing over %d neighbours" % neighbours)
ax.set_xlim(extent[0], extent[1])
ax.set_ylim(extent[2], extent[3])
plt.show()
结果:
【讨论】:
喜欢这个。 Graph 和 Alejandro 的回答一样好,但不需要新的包。 非常好!但是您使用此方法生成偏移量。您可以通过将正常散点图与彩色散点图进行比较来看到这一点。你能添加一些东西来纠正它吗?或者只是将图形移动 x 和 y 值? Agape Gal'lo,偏移量是什么意思?如果您将它们绘制在彼此之上,它们确实匹配(请参阅我的帖子的编辑)。也许你被推迟了,因为散布的宽度与其他三个不完全匹配。 非常感谢您为我绘制图表!我明白我的错误:我修改了“范围”来定义 x 和 y 限制。我现在明白它修改了图表的起源。然后,我有最后一个问题:即使对于没有现有数据的区域,如何扩展图形的限制?例如,对于 x 和 y,介于 -5 到 +5 之间。 假设x轴从-5到5,y轴从-3到4;在myplot
函数中,将range
参数添加到np.histogram2d
:np.histogram2d(x, y, bins=bins, range=[[-5, 5], [-3, 4]])
并在for 循环中设置轴的x 和y 边界:ax.set_xlim([-5, 5])
ax.set_ylim([-3, 4])
。此外,默认情况下,imshow
保持纵横比与轴的比例相同(因此在我的示例中为 10:7),但如果您希望它与您的绘图窗口匹配,请将参数 aspect='auto'
添加到imshow
.【参考方案4】:
我不想使用 np.hist2d,它通常会产生非常丑陋的直方图,我想回收 py-sphviewer,这是一个使用自适应平滑内核渲染粒子模拟的 python 包,可以从 pip 轻松安装(参见网页文档)。考虑以下基于示例的代码:
import numpy as np
import numpy.random
import matplotlib.pyplot as plt
import sphviewer as sph
def myplot(x, y, nb=32, xsize=500, ysize=500):
xmin = np.min(x)
xmax = np.max(x)
ymin = np.min(y)
ymax = np.max(y)
x0 = (xmin+xmax)/2.
y0 = (ymin+ymax)/2.
pos = np.zeros([len(x),3])
pos[:,0] = x
pos[:,1] = y
w = np.ones(len(x))
P = sph.Particles(pos, w, nb=nb)
S = sph.Scene(P)
S.update_camera(r='infinity', x=x0, y=y0, z=0,
xsize=xsize, ysize=ysize)
R = sph.Render(S)
R.set_logscale()
img = R.get_image()
extent = R.get_extent()
for i, j in zip(xrange(4), [x0,x0,y0,y0]):
extent[i] += j
print extent
return img, extent
fig = plt.figure(1, figsize=(10,10))
ax1 = fig.add_subplot(221)
ax2 = fig.add_subplot(222)
ax3 = fig.add_subplot(223)
ax4 = fig.add_subplot(224)
# Generate some test data
x = np.random.randn(1000)
y = np.random.randn(1000)
#Plotting a regular scatter plot
ax1.plot(x,y,'k.', markersize=5)
ax1.set_xlim(-3,3)
ax1.set_ylim(-3,3)
heatmap_16, extent_16 = myplot(x,y, nb=16)
heatmap_32, extent_32 = myplot(x,y, nb=32)
heatmap_64, extent_64 = myplot(x,y, nb=64)
ax2.imshow(heatmap_16, extent=extent_16, origin='lower', aspect='auto')
ax2.set_title("Smoothing over 16 neighbors")
ax3.imshow(heatmap_32, extent=extent_32, origin='lower', aspect='auto')
ax3.set_title("Smoothing over 32 neighbors")
#Make the heatmap using a smoothing over 64 neighbors
ax4.imshow(heatmap_64, extent=extent_64, origin='lower', aspect='auto')
ax4.set_title("Smoothing over 64 neighbors")
plt.show()
生成以下图像:
如您所见,图像看起来非常漂亮,我们能够识别其上的不同子结构。这些图像被构造为在某个域内为每个点传播给定的权重,由平滑长度定义,而平滑长度又由到较近的 nb 邻居的距离给出(我选择了 16、32和 64 的例子)。因此,与密度较低的区域相比,密度较高的区域通常分布在较小的区域中。
函数 myplot 只是我编写的一个非常简单的函数,用于将 x,y 数据提供给 py-sphviewer 以发挥作用。
【讨论】:
对任何试图在 OSX 上安装 py-sphviewer 的人的评论:我遇到了很多困难,请参阅:github.com/alejandrobll/py-sphviewer/issues/3 太糟糕了,它不适用于 python3。它会安装,但当您尝试使用它时会崩溃... @Fabio Dias,最新版本 (1.1.x) 现在适用于 Python 3。ValueError: Max 127 dimensions allowed
使用pip3 install py-sphviewer
和上面的代码。 Python 3.8.6
示例中显示的代码已被弃用。我对其进行了更新,使其适用于新版本。【参考方案5】:
如果您使用的是 1.2.x
import numpy as np
import matplotlib.pyplot as plt
x = np.random.randn(100000)
y = np.random.randn(100000)
plt.hist2d(x,y,bins=100)
plt.show()
【讨论】:
【参考方案6】:Seaborn 现在有了jointplot function,在这里应该可以很好地工作:
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
# Generate some test data
x = np.random.randn(8873)
y = np.random.randn(8873)
sns.jointplot(x=x, y=y, kind='hex')
plt.show()
【讨论】:
@wordsforthewise 你如何使用它使 600k 数据在视觉上可读? (如何调整大小) 我不太清楚你的意思;也许最好您提出一个单独的问题并将其链接到此处。你的意思是调整整个无花果的大小?先用fig = plt.figure(figsize=(12, 12))
制作图形,然后用ax=plt.gca()
获取当前坐标轴,然后在jointplot
函数中添加参数ax=ax
。【参考方案7】:
这里是 Jurgy's great nearest neighbour approach,但使用 scipy.cKDTree 实现。在我的测试中,它快了大约 100 倍。
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from scipy.spatial import cKDTree
def data_coord2view_coord(p, resolution, pmin, pmax):
dp = pmax - pmin
dv = (p - pmin) / dp * resolution
return dv
n = 1000
xs = np.random.randn(n)
ys = np.random.randn(n)
resolution = 250
extent = [np.min(xs), np.max(xs), np.min(ys), np.max(ys)]
xv = data_coord2view_coord(xs, resolution, extent[0], extent[1])
yv = data_coord2view_coord(ys, resolution, extent[2], extent[3])
def kNN2DDens(xv, yv, resolution, neighbours, dim=2):
"""
"""
# Create the tree
tree = cKDTree(np.array([xv, yv]).T)
# Find the closest nnmax-1 neighbors (first entry is the point itself)
grid = np.mgrid[0:resolution, 0:resolution].T.reshape(resolution**2, dim)
dists = tree.query(grid, neighbours)
# Inverse of the sum of distances to each grid point.
inv_sum_dists = 1. / dists[0].sum(1)
# Reshape
im = inv_sum_dists.reshape(resolution, resolution)
return im
fig, axes = plt.subplots(2, 2, figsize=(15, 15))
for ax, neighbours in zip(axes.flatten(), [0, 16, 32, 63]):
if neighbours == 0:
ax.plot(xs, ys, 'k.', markersize=5)
ax.set_aspect('equal')
ax.set_title("Scatter Plot")
else:
im = kNN2DDens(xv, yv, resolution, neighbours)
ax.imshow(im, origin='lower', extent=extent, cmap=cm.Blues)
ax.set_title("Smoothing over %d neighbours" % neighbours)
ax.set_xlim(extent[0], extent[1])
ax.set_ylim(extent[2], extent[3])
plt.savefig('new.png', dpi=150, bbox_inches='tight')
【讨论】:
我知道我的实现效率很低,但不知道 cKDTree。做得好!我会在回答中引用你。【参考方案8】:最初的问题是……如何将散点值转换为网格值,对吗?
histogram2d
确实计算了每个单元格的频率,但是,如果每个单元格还有其他数据而不仅仅是频率,则需要做一些额外的工作。
x = data_x # between -10 and 4, log-gamma of an svc
y = data_y # between -4 and 11, log-C of an svc
z = data_z #between 0 and 0.78, f1-values from a difficult dataset
所以,我有一个数据集,其中包含 X 和 Y 坐标的 Z 结果。但是,我正在计算感兴趣区域之外的几个点(大间隙),以及一小部分感兴趣区域中的大量点。
是的,这里变得更难但也更有趣。一些图书馆(对不起):
from matplotlib import pyplot as plt
from matplotlib import cm
import numpy as np
from scipy.interpolate import griddata
pyplot 是我今天的图形引擎, cm 是一系列具有一些有趣选择的彩色地图。 numpy 用于计算, 和 griddata 用于将值附加到固定网格。
最后一个很重要,尤其是因为 xy 点的频率在我的数据中分布不均。首先,让我们从适合我的数据和任意网格大小的边界开始。原始数据的数据点也在 x 和 y 边界之外。
#determine grid boundaries
gridsize = 500
x_min = -8
x_max = 2.5
y_min = -2
y_max = 7
所以我们定义了一个网格,在 x 和 y 的最小值和最大值之间有 500 个像素。
在我的数据中,有很多超过 500 个值可用于高关注区域;而在低兴趣区域,整个网格中甚至没有 200 个值;在x_min
和x_max
的图形边界之间,就更少了。
因此,为了获得一张漂亮的照片,任务是获取高兴趣值的平均值并填补其他地方的空白。
我现在定义我的网格。对于每个 xx-yy 对,我想要一个颜色。
xx = np.linspace(x_min, x_max, gridsize) # array of x values
yy = np.linspace(y_min, y_max, gridsize) # array of y values
grid = np.array(np.meshgrid(xx, yy.T))
grid = grid.reshape(2, grid.shape[1]*grid.shape[2]).T
为什么是奇怪的形状? scipy.griddata 想要 (n, D) 的形状。
Griddata 通过预定义的方法计算网格中每个点的一个值。 我选择“最近” - 空网格点将填充最近邻居的值。这看起来好像信息较少的区域具有更大的单元格(即使不是这样)。可以选择插入“线性”,然后信息较少的区域看起来不那么清晰。口味问题,真的。
points = np.array([x, y]).T # because griddata wants it that way
z_grid2 = griddata(points, z, grid, method='nearest')
# you get a 1D vector as result. Reshape to picture format!
z_grid2 = z_grid2.reshape(xx.shape[0], yy.shape[0])
然后跳,我们交给matplotlib来显示情节
fig = plt.figure(1, figsize=(10, 10))
ax1 = fig.add_subplot(111)
ax1.imshow(z_grid2, extent=[x_min, x_max,y_min, y_max, ],
origin='lower', cmap=cm.magma)
ax1.set_title("SVC: empty spots filled by nearest neighbours")
ax1.set_xlabel('log gamma')
ax1.set_ylabel('log C')
plt.show()
在 V 形的尖部周围,您会看到我在寻找最佳位置的过程中进行了很多计算,而几乎其他地方不太有趣的部分的分辨率较低。
【讨论】:
您能否改进您的答案以获得完整且可运行的代码?这是您提供的一个有趣的方法。我正在努力更好地理解它。我也不太明白为什么会有V形。谢谢。 V 形来自我的数据。它是经过训练的 SVM 的 f1 值:这在 SVM 的理论中有所体现。如果您的 C 值较高,则它会将您的所有点都包含在计算中,从而允许更广泛的伽马范围工作。 Gamma 是区分好坏曲线的刚度。这两个值必须提供给 SVM(我的图形中的 X 和 Y);然后你会得到一个结果(我的图形中的 Z)。在最好的区域,您有望达到有意义的高度。 第二次尝试:V 形在我的数据中。它是 SVM 的 f1 值:如果 C 值很高,它会在计算中包含所有点,从而允许更广泛的 gamma 范围工作,但会使计算变慢。 Gamma 是区分好坏曲线的刚度。这两个值必须提供给 SVM(我的图形中的 X 和 Y);然后你会得到一个结果(我的图形中的 Z)。在优化区域,您获得高值,在其他地方获得低值。如果你有一些 (X, Y) 的 Z 值和其他地方的许多间隙,我在这里展示的内容是可用的。如果你有 (X,Y,Z) 数据点,你可以使用我的代码。【参考方案9】:创建一个与最终图像中的单元格对应的二维数组,称为heatmap_cells
,并将其实例化为全零。
为每个维度选择两个缩放因子,以定义每个数组元素之间的差异(以实际单位表示),例如x_scale
和y_scale
。选择这些以使您的所有数据点都落在热图数组的范围内。
对于每个带有x_value
和y_value
的原始数据点:
heatmap_cells[floor(x_value/x_scale),floor(y_value/y_scale)]+=1
【讨论】:
【参考方案10】:与@Piti's answer 非常相似,但使用 1 次调用而不是 2 次来生成积分:
import numpy as np
import matplotlib.pyplot as plt
pts = 1000000
mean = [0.0, 0.0]
cov = [[1.0,0.0],[0.0,1.0]]
x,y = np.random.multivariate_normal(mean, cov, pts).T
plt.hist2d(x, y, bins=50, cmap=plt.cm.jet)
plt.show()
输出:
【讨论】:
【参考方案11】:这是我用 3 个类别(红色、绿色和蓝色)的 100 万分集制作的。如果您想尝试该功能,这里是存储库的链接。 Github Repo
histplot(
X,
Y,
labels,
bins=2000,
range=((-3,3),(-3,3)),
normalize_each_label=True,
colors = [
[1,0,0],
[0,1,0],
[0,0,1]],
gain=50)
【讨论】:
【参考方案12】:恐怕我参加聚会有点晚了,但不久前我也有类似的问题。接受的答案(@ptomato)帮助了我,但我也想发布这个以防它对某人有用。
''' I wanted to create a heatmap resembling a football pitch which would show the different actions performed '''
import numpy as np
import matplotlib.pyplot as plt
import random
#fixing random state for reproducibility
np.random.seed(1234324)
fig = plt.figure(12)
ax1 = fig.add_subplot(121)
ax2 = fig.add_subplot(122)
#Ratio of the pitch with respect to UEFA standards
hmap= np.full((6, 10), 0)
#print(hmap)
xlist = np.random.uniform(low=0.0, high=100.0, size=(20))
ylist = np.random.uniform(low=0.0, high =100.0, size =(20))
#UEFA Pitch Standards are 105m x 68m
xlist = (xlist/100)*10.5
ylist = (ylist/100)*6.5
ax1.scatter(xlist,ylist)
#int of the co-ordinates to populate the array
xlist_int = xlist.astype (int)
ylist_int = ylist.astype (int)
#print(xlist_int, ylist_int)
for i, j in zip(xlist_int, ylist_int):
#this populates the array according to the x,y co-ordinate values it encounters
hmap[j][i]= hmap[j][i] + 1
#Reversing the rows is necessary
hmap = hmap[::-1]
#print(hmap)
im = ax2.imshow(hmap)
这是结果
【讨论】:
以上是关于使用散点数据集在 MatPlotLib 中生成热图的主要内容,如果未能解决你的问题,请参考以下文章