每个单元格中带有文本的热图,带有 matplotlib 的 pyplot

Posted

技术标签:

【中文标题】每个单元格中带有文本的热图,带有 matplotlib 的 pyplot【英文标题】:Heatmap with text in each cell with matplotlib's pyplot 【发布时间】:2014-09-24 04:04:05 【问题描述】:

我使用matplotlib.pyplot.pcolor() 用 matplotlib 绘制热图:

import numpy as np
import matplotlib.pyplot as plt    

def heatmap(data, title, xlabel, ylabel):
    plt.figure()
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    c = plt.pcolor(data, edgecolors='k', linewidths=4, cmap='RdBu', vmin=0.0, vmax=1.0)
    plt.colorbar(c)

def main():
    title = "ROC's AUC"
    xlabel= "Timeshift"
    ylabel="Scales"
    data =  np.random.rand(8,12)
    heatmap(data, title, xlabel, ylabel)
    plt.show()

if __name__ == "__main__":
    main()

是否有任何方式在每个单元格中添加相应的值,例如:

(来自 Matlab 的Customizable Heat Maps)

(我目前的申请不需要额外的%,但我很想知道未来)

【问题讨论】:

【参考方案1】:

你需要通过调用axes.text()来添加所有的文本,这里是一个例子:

import numpy as np
import matplotlib.pyplot as plt    

title = "ROC's AUC"
xlabel= "Timeshift"
ylabel="Scales"
data =  np.random.rand(8,12)


plt.figure(figsize=(12, 6))
plt.title(title)
plt.xlabel(xlabel)
plt.ylabel(ylabel)
c = plt.pcolor(data, edgecolors='k', linewidths=4, cmap='RdBu', vmin=0.0, vmax=1.0)

def show_values(pc, fmt="%.2f", **kw):
    from itertools import izip
    pc.update_scalarmappable()
    ax = pc.get_axes()
    for p, color, value in izip(pc.get_paths(), pc.get_facecolors(), pc.get_array()):
        x, y = p.vertices[:-2, :].mean(0)
        if np.all(color[:3] > 0.5):
            color = (0.0, 0.0, 0.0)
        else:
            color = (1.0, 1.0, 1.0)
        ax.text(x, y, fmt % value, ha="center", va="center", color=color, **kw)

show_values(c)

plt.colorbar(c)

输出:

【讨论】:

效果很好,谢谢!希望有一天显示数字将作为 pcolor() 中的一个选项提供。【参考方案2】:

您可以使用Seaborn,这是一个基于 matplotlib 的 Python 可视化库,提供了用于绘制有吸引力的统计图形的高级接口。

Heatmap example:

import seaborn as sns
sns.set()

flights_long = sns.load_dataset("flights")
flights = flights_long.pivot("month", "year", "passengers")

sns.heatmap(flights, annot=True, fmt="d")

# To display the heatmap 
import matplotlib.pyplot as plt
plt.show()

# To save the heatmap as a file:
fig = heatmap.get_figure()
fig.savefig('heatmap.pdf')

文档:https://seaborn.pydata.org/generated/seaborn.heatmap.html

【讨论】:

【参考方案3】:

如果有人对此感兴趣,下面是我用来模仿问题中包含的 Matlab 的可定制热图的图片的代码。

import numpy as np
import matplotlib.pyplot as plt


def show_values(pc, fmt="%.2f", **kw):
    '''
    Heatmap with text in each cell with matplotlib's pyplot
    Source: http://***.com/a/25074150/395857 
    By HYRY
    '''
    from itertools import izip
    pc.update_scalarmappable()
    ax = pc.get_axes()
    for p, color, value in izip(pc.get_paths(), pc.get_facecolors(), pc.get_array()):
        x, y = p.vertices[:-2, :].mean(0)
        if np.all(color[:3] > 0.5):
            color = (0.0, 0.0, 0.0)
        else:
            color = (1.0, 1.0, 1.0)
        ax.text(x, y, fmt % value, ha="center", va="center", color=color, **kw)

def cm2inch(*tupl):
    '''
    Specify figure size in centimeter in matplotlib
    Source: http://***.com/a/22787457/395857
    By gns-ank
    '''
    inch = 2.54
    if type(tupl[0]) == tuple:
        return tuple(i/inch for i in tupl[0])
    else:
        return tuple(i/inch for i in tupl)

def heatmap(AUC, title, xlabel, ylabel, xticklabels, yticklabels):
    '''
    Inspired by:
    - http://***.com/a/16124677/395857 
    - http://***.com/a/25074150/395857
    '''

    # Plot it out
    fig, ax = plt.subplots()    
    c = ax.pcolor(AUC, edgecolors='k', linestyle= 'dashed', linewidths=0.2, cmap='RdBu', vmin=0.0, vmax=1.0)

    # put the major ticks at the middle of each cell
    ax.set_yticks(np.arange(AUC.shape[0]) + 0.5, minor=False)
    ax.set_xticks(np.arange(AUC.shape[1]) + 0.5, minor=False)

    # set tick labels
    #ax.set_xticklabels(np.arange(1,AUC.shape[1]+1), minor=False)
    ax.set_xticklabels(xticklabels, minor=False)
    ax.set_yticklabels(yticklabels, minor=False)

    # set title and x/y labels
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)      

    # Remove last blank column
    plt.xlim( (0, AUC.shape[1]) )

    # Turn off all the ticks
    ax = plt.gca()    
    for t in ax.xaxis.get_major_ticks():
        t.tick1On = False
        t.tick2On = False
    for t in ax.yaxis.get_major_ticks():
        t.tick1On = False
        t.tick2On = False

    # Add color bar
    plt.colorbar(c)

    # Add text in each cell 
    show_values(c)

    # resize 
    fig = plt.gcf()
    fig.set_size_inches(cm2inch(40, 20))



def main():
    x_axis_size = 19
    y_axis_size = 10
    title = "ROC's AUC"
    xlabel= "Timeshift"
    ylabel="Scales"
    data =  np.random.rand(y_axis_size,x_axis_size)
    xticklabels = range(1, x_axis_size+1) # could be text
    yticklabels = range(1, y_axis_size+1) # could be text   
    heatmap(data, title, xlabel, ylabel, xticklabels, yticklabels)
    plt.savefig('image_output.png', dpi=300, format='png', bbox_inches='tight') # use format='svg' or 'pdf' for vectorial pictures
    plt.show()


if __name__ == "__main__":
    main()
    #cProfile.run('main()') # if you want to do some profiling

输出:

有一些图案的时候会更好看:

【讨论】:

为了让这个有用的代码 sn-p 保持最新:Python3 中删除了“izip”,所以只需删除导入并使用 zip; "get_axes()" 已被弃用 - 而是只写 ".axes" @Chaoste 谢谢!【参考方案4】:

同@HYRY aswer,但python3兼容版本:

import numpy as np
import matplotlib.pyplot as plt    

title = "ROC's AUC"
xlabel= "Timeshift"
ylabel="Scales"
data =  np.random.rand(8,12)


plt.figure(figsize=(12, 6))
plt.title(title)
plt.xlabel(xlabel)
plt.ylabel(ylabel)
c = plt.pcolor(data, edgecolors='k', linewidths=4, cmap='RdBu', vmin=0.0, vmax=1.0)

def show_values(pc, fmt="%.2f", **kw):
    pc.update_scalarmappable()
    ax = pc.axes
    for p, color, value in zip(pc.get_paths(), pc.get_facecolors(), pc.get_array()):
        x, y = p.vertices[:-2, :].mean(0)
        if np.all(color[:3] > 0.5):
            color = (0.0, 0.0, 0.0)
        else:
            color = (1.0, 1.0, 1.0)
        ax.text(x, y, fmt % value, ha="center", va="center", color=color, **kw)

show_values(c)

plt.colorbar(c)

【讨论】:

以上是关于每个单元格中带有文本的热图,带有 matplotlib 的 pyplot的主要内容,如果未能解决你的问题,请参考以下文章

向热图中的特定单元格添加注释

重叠 yticklabels:是不是可以控制 seaborn 中热图的单元格大小?

在uitableview中带有空格填充的圆形单元格

一个 QTableView 单元格中带有图像的超链接

单元格中带有按钮的 Xamarin Telerik ListView - 按钮的单击事件不会触发

ggplot中带有圆圈而不是瓷砖的热图