How to fix plt.tight_layout() errors when plotting multiple heatmaps

I'm plotting several heatmaps and don't want their axis labels crowding each other. I tried plt.tight_layout(), but it raises an error.

Here is my result without plt.tight_layout():

[Image: heatmaps without plt.tight_layout()]

How can I plot these heatmaps together without overlap or crowding? And what does this plt.tight_layout() error mean?

Here is the error I get:

Traceback (most recent call last):
  File "C:/Users/mbsta/Desktop/OOL_Research/heatmaptrying.py", line 125, in <module>
    plt.tight_layout()
  File "C:\Users\mbsta\Anaconda3\lib\site-packages\matplotlib\pyplot.py", line 1406, in tight_layout
    fig.tight_layout(pad=pad, h_pad=h_pad, w_pad=w_pad, rect=rect)
  File "C:\Users\mbsta\Anaconda3\lib\site-packages\matplotlib\figure.py", line 1755, in tight_layout
    self.subplots_adjust(**kwargs)
  File "C:\Users\mbsta\Anaconda3\lib\site-packages\matplotlib\figure.py", line 1633, in subplots_adjust
    ax.update_params()
  File "C:\Users\mbsta\Anaconda3\lib\site-packages\matplotlib\axes\_subplots.py", line 114, in update_params
    return_all=True)
  File "C:\Users\mbsta\Anaconda3\lib\site-packages\matplotlib\gridspec.py", line 429, in get_position
    gridspec.get_grid_positions(fig)
  File "C:\Users\mbsta\Anaconda3\lib\site-packages\matplotlib\gridspec.py", line 90, in get_grid_positions
    subplot_params = self.get_subplot_params(fig)
  File "C:\Users\mbsta\Anaconda3\lib\site-packages\matplotlib\gridspec.py", line 376, in get_subplot_params
    hspace=hspace)
  File "C:\Users\mbsta\Anaconda3\lib\site-packages\matplotlib\figure.py", line 193, in __init__
    self.update(left, bottom, right, top, wspace, hspace)
  File "C:\Users\mbsta\Anaconda3\lib\site-packages\matplotlib\figure.py", line 228, in update
    raise ValueError('left cannot be >= right')
ValueError: left cannot be >= right
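Reading the traceback from the bottom: tight_layout() computes new margins and passes them on through subplots_adjust(), and the last frames (in figure.py) validate those subplot parameters and reject any layout whose left edge would sit at or beyond its right edge. A minimal sketch that triggers the same check directly, assuming a recent matplotlib; the helper name trigger_left_right_error is hypothetical and the values 0.6 and 0.5 are arbitrary, chosen only to violate left < right:

```python
import matplotlib
matplotlib.use("Agg")  # off-screen backend so the sketch runs without a display
import matplotlib.pyplot as plt


def trigger_left_right_error():
    """Hypothetical helper: return the message matplotlib raises for left >= right."""
    fig = plt.figure()
    try:
        # left and right are fractions of the figure width; a left edge at 0.6
        # and a right edge at 0.5 leave no room, so matplotlib rejects them
        fig.subplots_adjust(left=0.6, right=0.5)
    except ValueError as exc:
        return str(exc)
    finally:
        plt.close(fig)
    return None  # reached only if no error was raised


print(trigger_left_right_error())  # left cannot be >= right
```

In the question, tight_layout() ends up requesting such a margin pair itself when the decorations need more width than the figure provides.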

Here is my code:

import matplotlib.pyplot as plt
import numpy as np
import copy

#open a matplotlib Heatmap plt
ensamblePlotTitle = 'Plotting the concentration of species as a function of Temperature and pH'
fig = plt.figure(figsize=plt.figaspect(1))
# the whole titlebox's title
fig.suptitle(ensamblePlotTitle)

# create x, y -> arrays that define the axes of the plots
def doit(x_min, x_max, y_min, y_max):
    x = list(range(x_min, x_max + 10 , 10))
    xLen = len(x)
    y = list(range(y_min, y_max + 1, 1))
    yLen = len(y)
    return x, y

# create an empty matrix that is of proper dimension for the data
def emptyMatrix(x, y):
    matrix = []
    for row in y:
        row_temp = []
        for col in x:
            index = [col, row]
            row_temp.append(index)
        matrix.append(row_temp)
    return matrix

# Below is some sample data. For testing, every dataset currently uses the same label (speciesName).
speciesName = 'Gly5'
vectorList1 = [[90.0, 9.0, 2.3984407629124322e-05], [100.0, 9.0, 0.00014507819388370692], [100.0, 8.0, 2.948628706668523e-05], [100.0, 3.0, 1.0490510000464058e-05]]
# speciesName AGG
vectorList2 = [[90.0, 11.0, 3.98338670007889e-05], [90.0, 10.0, 0.0001568694602997819], [90.0, 9.0, 0.00021729500208826397], [90.0, 8.0, 3.427954893498538e-05], [80.0, 11.0, 5.005800733212678e-06], [80.0, 10.0, 1.8917815211842776e-05], [80.0, 9.0, 3.260197688987888e-05], [100.0, 10.0, 0.0025925430414404384], [100.0, 8.0, 0.00012735904218293194], [100.0, 7.0, 8.764211796371061e-06], [100.0, 3.0, 6.556174300431575e-05], [100.0, 2.0, 7.130029235695392e-05], [100.0, 1.0, 2.8579516450879392e-05]]
# speciesName AG
vectorList3 = [[90.0, 12.0, 7.857065300006515e-05], [90.0, 11.0, 0.00018731032885259906], [90.0, 10.0, 0.0003595914496876516], [90.0, 9.0, 0.00016362990148175496], [90.0, 8.0, 2.834531332566035e-05], [90.0, 7.0, 1.8552701108239776e-06], [90.0, 3.0, 4.299622865498346e-06], [90.0, 2.0, 2.588271913251826e-05], [80.0, 12.0, 2.9946940577790324e-05], [80.0, 11.0, 7.045827994817115e-05], [80.0, 10.0, 0.0001310902154950886], [80.0, 9.0, 5.4234692030952545e-05], [70.0, 12.0, 2.2564659384568613e-05], [70.0, 11.0, 4.064178013275714e-05], [70.0, 10.0, 3.99009750483181e-05], [70.0, 9.0, 2.1472055129680704e-05], [70.0, 8.0, 1.4141458012117527e-06], [70.0, 3.0, 6.539410917356151e-07], [60.0, 12.0, 8.517883124497818e-06], [60.0, 11.0, 9.11130896798338e-06], [60.0, 10.0, 8.557261467856704e-06], [60.0, 9.0, 3.852707623039227e-06], [50.0, 12.0, 3.847351010155849e-06], [50.0, 11.0, 3.5498418351466194e-06], [50.0, 10.0, 2.8449405343583283e-06], [100.0, 12.0, 0.00014641664314100198], [100.0, 11.0, 0.00039152554887185394], [100.0, 9.0, 0.00043717011589104363], [100.0, 8.0, 5.702172323684191e-05], [100.0, 7.0, 4.194951754290719e-06], [100.0, 3.0, 4.116788638189755e-05], [100.0, 2.0, 0.00011746024162667303]]
# speciesName AA &/or GA
vectorList4 = [[90.0, 12.0, 5.096802681201908e-05], [90.0, 11.0, 0.0001086474552469471], [90.0, 10.0, 0.0001874050106046458], [90.0, 9.0, 3.369801732936655e-05], [90.0, 9.0, 5.286861097237001e-05], [90.0, 8.0, 1.3938341042512686e-05], [90.0, 2.0, 1.7306926679551494e-05], [80.0, 12.0, 1.889885411916292e-05], [80.0, 11.0, 4.0157802920077886e-05], [80.0, 10.0, 6.100950436853497e-05], [80.0, 9.0, 1.9450730019472016e-05], [80.0, 2.0, 8.24310336091265e-06], [70.0, 12.0, 1.3419183912788552e-05], [70.0, 11.0, 2.1882921815167904e-05], [70.0, 10.0, 1.996293802977987e-05], [70.0, 9.0, 7.807190890862633e-06], [100.0, 12.0, 9.576002374282467e-05], [100.0, 11.0, 0.0002486277660752677], [100.0, 10.0, 0.00048311682483152005], [100.0, 8.0, 1.1431446213091852e-05], [100.0, 3.0, 2.3901930552237836e-05], [100.0, 2.0, 8.24111996641259e-05]]
# speciesName GGA
vectorList5 = [[90.0, 10.0, 4.6398440762912434e-05], [90.0, 9.0, 5.527820316488004e-05], [100.0, 12.0, 1.0628799480254304e-05], [100.0, 11.0, 6.78597614738503e-05], [100.0, 10.0, 0.0002986765047101954], [100.0, 9.0, 0.0002539820873358393], [100.0, 8.0, 3.658684857766022e-05], [100.0, 3.0, 2.1463641004222935e-05], [100.0, 2.0, 3.2324005754327344e-05]]

#these are min and max values from data that will define bounds on axes.
x_min = 40
x_max = 100
y_min = 1
y_max = 12

# create x, y => arrays that will define subplot axes.
x, y = doit(x_min, x_max, y_min, y_max)
# print(x)
# print(y)

# create an empty matrix of the proper dimensions for the data/subplots
matrix = emptyMatrix(x, y)
# for i in matrix:
#     print(i)

# create a heatmap subplot for each dataset (vectorList)
def createHeatPlot(count, speciesnum, matrix, vectorList,x,y):
    # 5 plots per row => 5 columns of plots
    # if more than 5 plots attempted, go to a new row
    num_columns = 5
    if speciesnum % num_columns == 0:
        num_rows = int(speciesnum / num_columns)
    else:
        num_rows = int(speciesnum / num_columns) + 1
    # print('looking at this species:', eachSpecies)

    #
    concentrationMatrix = copy.deepcopy(matrix)
    concentrationVector = copy.deepcopy(vectorList)
    for i, row in enumerate(concentrationMatrix):
        for j, element in enumerate(row):
            # compare element [temp, ph] to concentrationVector: check whether any entry of concentrationVector has a matching [temp, pH]
            # if element doesn't match, no concentration was detected (=> concentration of zero): change element from [temp, ph] to 0
            # if element does match, change element from [temp, ph] to the concentration (animal[2])

            # print('looking at this element', element)
            vectorMatch = 'No'
            for animal in concentrationVector:
                # print('looking at this animal', animal)
                if animal[0]==element[0] and animal[1]==element[1]:
                    vectorMatch = 'Yes'
                    # print('vectorMatch found with this animal', animal)
                    concentrationMatrix[i][j] = animal[2]
            if vectorMatch == 'No':
                # print('NO vectorMatch found with this element')
                concentrationMatrix[i][j] = 0

    # roster position of plot across the grid
    plotLocation = count + 1

    ax = fig.add_subplot(num_rows, num_columns, plotLocation)

    # give the plot axis some labels
    ax.set_title(speciesName)
    ax.set_xlabel('Temperature')
    ax.set_ylabel('pH')
    intensity = concentrationMatrix

    #setup the 2D grid with Numpy
    x, y = np.meshgrid(x, y)

    #convert intensity (list of lists) to a numpy array for plotting
    intensity = np.array(intensity)

    #now just plug the data into pcolormesh, it's that easy!
    plt.pcolormesh(x, y, intensity)
    plt.colorbar() #need a colorbar to show the intensity scale

speciesnum = 5
createHeatPlot(0, speciesnum, matrix, vectorList1,x,y)
createHeatPlot(1, speciesnum, matrix, vectorList2,x,y)
createHeatPlot(2, speciesnum, matrix, vectorList3,x,y)
createHeatPlot(3, speciesnum, matrix, vectorList4,x,y)
createHeatPlot(4, speciesnum, matrix, vectorList5,x,y)

# This helps prevent labels of one plot from overlapping those of other plots
plt.tight_layout()
plt.show()  # boom
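As an aside unrelated to the layout error, the nested loop in createHeatPlot compares every grid cell against every data point to fill concentrationMatrix. A minimal NumPy sketch of the same mapping, using a hypothetical helper named concentration_grid, does one pass over the data instead; it assumes the temperatures and pH values in the data are exactly grid values from doit(), as in the sample data:

```python
import numpy as np


def concentration_grid(vector_list, x, y):
    """Hypothetical helper: len(y) x len(x) grid of concentrations, 0 where no data.

    Rows follow the pH values in y and columns the temperatures in x, matching
    the layout of concentrationMatrix in the question.
    """
    grid = np.zeros((len(y), len(x)))
    # Map each grid value to its index once (90.0 and 90 hash equal, so float
    # temperatures from the data match the integer grid values)
    col_of = {temp: j for j, temp in enumerate(x)}
    row_of = {ph: i for i, ph in enumerate(y)}
    for temp, ph, conc in vector_list:
        # Skip points outside the grid; later duplicates overwrite earlier ones,
        # as in the original loop
        if temp in col_of and ph in row_of:
            grid[row_of[ph], col_of[temp]] = conc
    return grid


# Same axes as doit(40, 100, 1, 12) in the question
x = list(range(40, 110, 10))
y = list(range(1, 13))
grid = concentration_grid([[90.0, 9.0, 2.3984407629124322e-05]], x, y)
print(grid.shape)  # (12, 7)
```

The returned array can go straight to pcolormesh in place of np.array(intensity), and it removes the need for the copy.deepcopy calls.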
Answer

As always, a Minimal, Complete, and Verifiable example is essential in a case like this.

import matplotlib.pyplot as plt
import numpy as np

fig = plt.figure(figsize=(4.8,4.8))
for i in range(5):
    ax = fig.add_subplot(1, 5, i+1)
    ax.set_title("Title {}".format(i))
    ax.set_xlabel('Temperature')
    ax.set_ylabel('pH')
    x, y = np.meshgrid(np.arange(6),np.arange(11))
    intensity = np.random.rand(10,5)
    pc = ax.pcolormesh(x, y, intensity)
    fig.colorbar(pc, ax=ax)

plt.tight_layout()
plt.show()

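The code above reproduces the error from the question. The reason is simple: the elements of the figure do not fit inside the figure. A slightly larger figure size makes this visible; changing the figure width from 4.8 (the default from plt.figaspect(1)) to 4.9,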

fig = plt.figure(figsize=(4.9,4.8))

produces

[Image: output with figsize=(4.9, 4.8)]

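The axes are now essentially zero width; not exactly zero, because that would trigger the error above. Either way, the output is useless, so a larger figure is needed. For example, changing the figure width to 10 inches,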

fig = plt.figure(figsize=(10,4.8))

produces a reasonable result.

[Image: output with figsize=(10, 4.8)]
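The same fix can be generalized: instead of hard-coding 10 inches, scale the figure width with the number of columns, then check each heatmap's width after tight_layout() to confirm nothing was squashed to near zero. A minimal sketch reusing the answer's random-data example; the helper name heatmap_grid and the defaults of 2.0 inches per panel and 4.8 inches per row are assumptions (2.0 x 5 columns reproduces the 10-inch width above), not values guaranteed to fit every set of labels:

```python
import math

import matplotlib
matplotlib.use("Agg")  # render off-screen; drop this line to use plt.show()
import matplotlib.pyplot as plt
import numpy as np


def heatmap_grid(n_panels, ncols=5, panel_width=2.0, panel_height=4.8):
    """Hypothetical helper: draw n_panels random heatmaps on a figure sized per panel."""
    nrows = math.ceil(n_panels / ncols)  # same as the question's if/else
    fig = plt.figure(figsize=(ncols * panel_width, nrows * panel_height))
    axes = []
    for i in range(n_panels):
        ax = fig.add_subplot(nrows, ncols, i + 1)
        ax.set_title("Title {}".format(i))
        ax.set_xlabel("Temperature")
        ax.set_ylabel("pH")
        x, y = np.meshgrid(np.arange(6), np.arange(11))
        pc = ax.pcolormesh(x, y, np.random.rand(10, 5))
        fig.colorbar(pc, ax=ax)
        axes.append(ax)
    fig.tight_layout()
    # Width of each heatmap in inches after layout; values near 0 mean squashed
    widths = [ax.get_position().width * fig.get_figwidth() for ax in axes]
    return fig, widths


fig, widths = heatmap_grid(5)
print(fig.get_size_inches(), min(widths))
```

Scaling by columns keeps the space per heatmap roughly constant as more panels are added, so the question's "go to a new row after 5 plots" logic keeps working without retuning the figure size.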
