使用 matplotlib 绘制大量点并耗尽内存

Posted

技术标签:

【中文标题】使用 matplotlib 绘制大量点并耗尽内存【英文标题】:Plotting a large number of points using matplotlib and running out of memory 【发布时间】:2013-12-13 13:55:22 【问题描述】:

我有一个大的 (~6GB) 简单格式的文本文件

x1 y1 z1
x2 y2 z2
...

由于我可能会多次加载此数据,因此出于效率原因,我创建了一个np.memmap 文件:

X,Y,Z = np.memmap(f_np_mmap,dtype='float32',mode='r',shape=shape).T

我想做的是情节:

plt.scatter(X, Y, 
           color=custom_colorfunction(Z), 
           alpha=.01, s=.001, marker='s', linewidth=0)

这非常适用于较小的数据集。但是,对于这个更大的数据集,我的内存不足。我检查了plt.scatter 占用了所有内存;我可以通过X,Y,Z 就好了。有没有办法让我“光栅化”画布,以免内存不足?我不需要缩放和平移图像,它会直接进入磁盘。我意识到我可以对数据进行分类并绘制它,但我不确定如何使用自定义颜色图 alpha 值来做到这一点。

【问题讨论】:

matplotlib 出于自卫的原因制作数据的内部副本(如果它只是保留一个引用,则数据可能/将在其下更改)。我会考虑直接使用PathCollection(或它在下面使用的)。 另一种选择是编写Axes 的自定义子类,它覆盖draw 函数并为每个点生成一个艺术家,将其栅格化并合成下来,然后扔掉艺术家在制作下一个之前离开。 @tcaswell 第一种方法只有在内部表示是问题时才有帮助,但它不能解决潜在的大小问题。您的第二个解决方案很有趣,我如何栅格化/合成艺术家? 我认为内部表示是问题所在,在某个地方有 _internal_data = np.array(input_data) 的道德等价物,如果你是颜色映射,散射对象最终将至少为 6*N*64 B ( N 行,6 或 7 个浮点数(XYZ,RGB(可能是 A),64 位) 给我一点时间看看我能不能让它工作,但是看看matplotlib.axes.Axes.draw的代码 【参考方案1】:

类似这样的(对不起,代码很长,大部分是从标准axes.Axes.draw复制的):

from operator import itemgetter
class generator_scatter_axes(matplotlib.axes.Axes):
    def __init__(self, *args, **kwargs):
        matplotlib.axes.Axes.__init__(self, *args, **kwargs)
        self._big_data = None
    def draw(self, renderer=None, inframe=None):
        # copied from original draw (so you can still add normal artists ect)
        if renderer is None:
            renderer = self._cachedRenderer

        if renderer is None:
            raise RuntimeError('No renderer defined')
        if not self.get_visible():
            return
        renderer.open_group('axes')

        locator = self.get_axes_locator()
        if locator:
            pos = locator(self, renderer)
            self.apply_aspect(pos)
        else:
            self.apply_aspect()


        artists = []

        artists.extend(self.collections)
        artists.extend(self.patches)
        artists.extend(self.lines)
        artists.extend(self.texts)
        artists.extend(self.artists)
        if self.axison and not inframe:
            if self._axisbelow:
                self.xaxis.set_zorder(0.5)
                self.yaxis.set_zorder(0.5)
            else:
                self.xaxis.set_zorder(2.5)
                self.yaxis.set_zorder(2.5)
            artists.extend([self.xaxis, self.yaxis])
        if not inframe:
            artists.append(self.title)
            artists.append(self._left_title)
            artists.append(self._right_title)
        artists.extend(self.tables)
        if self.legend_ is not None:
            artists.append(self.legend_)

        # the frame draws the edges around the axes patch -- we
        # decouple these so the patch can be in the background and the
        # frame in the foreground.
        if self.axison and self._frameon:
            artists.extend(self.spines.itervalues())

        if self.figure.canvas.is_saving():
            dsu = [(a.zorder, a) for a in artists]
        else:
            dsu = [(a.zorder, a) for a in artists
                   if not a.get_animated()]

        # add images to dsu if the backend support compositing.
        # otherwise, does the manaul compositing  without adding images to dsu.
        if len(self.images) <= 1 or renderer.option_image_nocomposite():
            dsu.extend([(im.zorder, im) for im in self.images])
            _do_composite = False
        else:
            _do_composite = True

        dsu.sort(key=itemgetter(0))

        # rasterize artists with negative zorder
        # if the minimum zorder is negative, start rasterization
        rasterization_zorder = self._rasterization_zorder
        if (rasterization_zorder is not None and
            len(dsu) > 0 and dsu[0][0] < rasterization_zorder):
            renderer.start_rasterizing()
            dsu_rasterized = [l for l in dsu if l[0] < rasterization_zorder]
            dsu = [l for l in dsu if l[0] >= rasterization_zorder]
        else:
            dsu_rasterized = []

        # the patch draws the background rectangle -- the frame below
        # will draw the edges
        if self.axison and self._frameon:
            self.patch.draw(renderer)

        if _do_composite:
            # make a composite image blending alpha
            # list of (mimage.Image, ox, oy)

            zorder_images = [(im.zorder, im) for im in self.images
                             if im.get_visible()]
            zorder_images.sort(key=lambda x: x[0])

            mag = renderer.get_image_magnification()
            ims = [(im.make_image(mag), 0, 0, im.get_alpha()) for z, im in zorder_images]

            l, b, r, t = self.bbox.extents
            width = mag * ((round(r) + 0.5) - (round(l) - 0.5))
            height = mag * ((round(t) + 0.5) - (round(b) - 0.5))
            im = mimage.from_images(height,
                                    width,
                                    ims)

            im.is_grayscale = False
            l, b, w, h = self.bbox.bounds
            # composite images need special args so they will not
            # respect z-order for now

            gc = renderer.new_gc()
            gc.set_clip_rectangle(self.bbox)
            gc.set_clip_path(mtransforms.TransformedPath(
                    self.patch.get_path(),
                    self.patch.get_transform()))

            renderer.draw_image(gc, round(l), round(b), im)
            gc.restore()

        if dsu_rasterized:
            for zorder, a in dsu_rasterized:
                a.draw(renderer)
            renderer.stop_rasterizing()

        for zorder, a in dsu:
            a.draw(renderer)
        ############################    
        # new bits
        ############################
        if self._big_data is not None:

            for x, y, z in self._big_data:
                # add the (single point) to the axes
                a = self.scatter(x, y, color='r',
                            alpha=1, s=10, marker='s', linewidth=0)
                # add the point, in Agg this will render + composite
                a.draw(renderer)
                # remove the artist from the axes, shouldn't let the render know
                a.remove()
                # delete the artist for good measure
                del a
        #######################
        # end new bits
        #######################    
        # again, from original to clean up
        renderer.close_group('axes')
        self._cachedRenderer = renderer

像这样使用它:

In [42]: fig = figure()

In [43]: ax = generator_scatter_axes(fig, [.1, .1, .8, .8])

In [44]: fig.add_axes(ax)
Out[44]: <__main__.generator_scatter_axes at 0x56fe090>

In [45]: ax._big_data = rand(500, 3)

In [46]: draw()

我更改了您的 scatter 函数,使其具有少量可见的形状。这将非常慢,因为您每次 都设置scatter 对象。我会获取您的数据的合理块并绘制它们,或者将对scatter 的调用替换为基础艺术家对象,或者使用乔的建议并仅更新单个艺术家。

【讨论】:

是否有可能在draw 中获得缩放级别,并根据缩放和dpi 绘制一个矩形(数据块的最小和最大数据)或绘制数据本身?像细节级别绘图这样我们可以绘制大数据,然后只有在放大后才能看到细节。也许我需要为此培养一位新艺术家。但想知道是否可以知道缩放级别。 在干时,Axes 知道它的视图限制是什么。看看Line2Ddraw 方法,它使用这些知识(在某些情况下)只查看数据的一个子集。【参考方案2】:

@tcaswell 建议重写 Axes.draw 方法绝对是最灵活的方法。

但是,您可以使用/滥用 blitting 来执行此操作,而无需子类化 Axes。每次只需使用draw_artist,无需恢复画布。

还有一个额外的技巧:我们需要一个特殊的save 方法,因为所有其他人都在保存之前绘制了画布,这将清除我们之前在其上绘制的所有内容。

此外,正如 tcaswell 所说,为每个项目调用 draw_artist 相当慢,因此对于大量点,您需要对输入数据进行分块。分块会显着加快速度,但这种方法总是比绘制单个 PathCollection 慢。

无论如何,这些答案中的任何一个都应该可以缓解您的记忆问题。这是一个简单的例子。

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import _png
from itertools import izip

def main():
    # We'll be saving the figure's background, so let's make it transparent.
    fig, ax = plt.subplots(facecolor='none')

    # You'll have to know the extent of the input beforehand with this method.
    ax.axis([0, 10, 0, 10])

    # We need to draw the canvas before we start adding points.
    fig.canvas.draw()

    # This won't actually ever be drawn. We just need an artist to update.
    col = ax.scatter([5], [5], color=[0.1, 0.1, 0.1], alpha=0.3)

    for xy, color in datastream(int(1e6), chunksize=int(1e4)):
        col.set_offsets(xy)
        col.set_color(color)
        ax.draw_artist(col)

    save(fig, 'test.png')

def datastream(n, chunksize=1):
    """Returns a generator over "n" random xy positions and rgb colors."""
    for _ in xrange(n//chunksize):
        xy = 10 * np.random.random((chunksize, 2))
        color = np.random.random((chunksize, 3))
        yield xy, color

def save(fig, filename):
    """We have to work around `fig.canvas.print_png`, etc calling `draw`."""
    renderer = fig.canvas.renderer
    with open(filename, 'w') as outfile:
        _png.write_png(renderer._renderer.buffer_rgba(),
                       renderer.width, renderer.height,
                       outfile, fig.dpi)

main()

此外,您可能会注意到顶部和左侧的脊椎被拉过。您可以通过在保存之前重新绘制这两个脊椎(ax.draw_artist(ax.spines['top']) 等)来解决此问题。

【讨论】:

谢谢,这很好用。我只想为其他使用它的人添加,可以通过在fig, ax = plt.subplots(facecolor='none', dpi=700)调用not中设置dpi来获得更高分辨率的图像,因为到那时似乎为时已晚。 @Hooked - 好点!对图形的任何其他更改也是如此...基本上,所有内容都需要在初始调用fig.canvas.draw() 之前进行设置。之后,我们只是在“固定”的光栅图像上绘制。【参考方案3】:

只是为了扩展已接受的答案,“解决方法”保存功能似乎不再起作用,因为 write_png 的签名已更改。我的解决方法如下:

import numpy as np
from PIL import Image

def png_write(fig, filename):
    width, height = map(int, fig.get_size_inches() * fig.get_dpi())
    image = np.frombuffer(fig.canvas.tostring_argb(), dtype='uint8')
    image = image.reshape(width, height, 4)
    image = np.roll(image, -1, 2)
    Image.fromarray(image, 'RGBA').save(filename)

【讨论】:

以上是关于使用 matplotlib 绘制大量点并耗尽内存的主要内容,如果未能解决你的问题,请参考以下文章

Python使用matplotlib可视化聚类图使用encircle函数绘制多边形标定属于同一聚类簇的数据点并自定义每个聚类簇的背景色(Cluster Plot)

如何使用 JDBC 将大量数据加载到文件中而不会耗尽内存?

如何使用 Python Ray 在不耗尽内存的情况下并行处理大量数据?

如何在嵌入Qt环境的matplotlib中更快地绘制大量信号?

Matplotlib:用不同颜色绘制大量断开的线段

如何解决PHP里大量数据循环时内存耗尽的问题