如何使用 plt.imshow 和 torchvision.utils.make_grid 在 PyTorch 中生成和显示图像网格?
Posted
技术标签:
【中文标题】如何使用 plt.imshow 和 torchvision.utils.make_grid 在 PyTorch 中生成和显示图像网格?【英文标题】:How can I generate and display a grid of images in PyTorch with plt.imshow and torchvision.utils.make_grid? 【发布时间】:2018-12-22 01:45:19 【问题描述】:我试图了解 torchvision 如何与 mathplotlib 交互以生成图像网格。生成图像并迭代显示它们很容易:
import torch
import torchvision
import matplotlib.pyplot as plt
w = torch.randn(10,3,640,640)
for i in range (0,10):
z = w[i]
plt.imshow(z.permute(1,2,0))
plt.show()
但是,在网格中显示这些图像似乎并不那么简单。
w = torch.randn(10,3,640,640)
grid = torchvision.utils.make_grid(w, nrow=5)
plt.imshow(grid)
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-61-1601915e10f3> in <module>()
1 w = torch.randn(10,3,640,640)
2 grid = torchvision.utils.make_grid(w, nrow=5)
----> 3 plt.imshow(grid)
/anaconda3/lib/python3.6/site-packages/matplotlib/pyplot.py in imshow(X, cmap, norm, aspect, interpolation, alpha, vmin, vmax, origin, extent, shape, filternorm, filterrad, imlim, resample, url, hold, data, **kwargs)
3203 filternorm=filternorm, filterrad=filterrad,
3204 imlim=imlim, resample=resample, url=url, data=data,
-> 3205 **kwargs)
3206 finally:
3207 ax._hold = washold
/anaconda3/lib/python3.6/site-packages/matplotlib/__init__.py in inner(ax, *args, **kwargs)
1853 "the Matplotlib list!)" % (label_namer, func.__name__),
1854 RuntimeWarning, stacklevel=2)
-> 1855 return func(ax, *args, **kwargs)
1856
1857 inner.__doc__ = _add_data_doc(inner.__doc__,
/anaconda3/lib/python3.6/site-packages/matplotlib/axes/_axes.py in imshow(self, X, cmap, norm, aspect, interpolation, alpha, vmin, vmax, origin, extent, shape, filternorm, filterrad, imlim, resample, url, **kwargs)
5485 resample=resample, **kwargs)
5486
-> 5487 im.set_data(X)
5488 im.set_alpha(alpha)
5489 if im.get_clip_path() is None:
/anaconda3/lib/python3.6/site-packages/matplotlib/image.py in set_data(self, A)
651 if not (self._A.ndim == 2
652 or self._A.ndim == 3 and self._A.shape[-1] in [3, 4]):
--> 653 raise TypeError("Invalid dimensions for image data")
654
655 if self._A.ndim == 3:
TypeError: Invalid dimensions for image data
尽管 PyTorch 的文档表明 w 是正确的形状,但 Python 说它不是。所以我试图置换我的张量的索引:
w = torch.randn(10,3,640,640)
grid = torchvision.utils.make_grid(w.permute(0,2,3,1), nrow=5)
plt.imshow(grid)
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-62-6f2dc6313e29> in <module>()
1 w = torch.randn(10,3,640,640)
----> 2 grid = torchvision.utils.make_grid(w.permute(0,2,3,1), nrow=5)
3 plt.imshow(grid)
/anaconda3/lib/python3.6/site-packages/torchvision-0.2.1-py3.6.egg/torchvision/utils.py in make_grid(tensor, nrow, padding, normalize, range, scale_each, pad_value)
83 grid.narrow(1, y * height + padding, height - padding)\
84 .narrow(2, x * width + padding, width - padding)\
---> 85 .copy_(tensor[k])
86 k = k + 1
87 return grid
RuntimeError: The expanded size of the tensor (3) must match the existing size (640) at non-singleton dimension 0
这里发生了什么?如何将一堆随机生成的图像放入网格中并显示?
【问题讨论】:
【参考方案1】:您的代码中有一个小错误。 torchvision.utils.make_grid() 返回一个包含图像网格的张量。但是通道维度必须移到最后,因为这是 matplotlib 识别的。以下是运行良好的代码:
In [107]: import torchvision
# sample input (10 RGB images containing just Gaussian Noise)
In [108]: batch_tensor = torch.randn(*(10, 3, 256, 256)) # (N, C, H, W)
# make grid (2 rows and 5 columns) to display our 10 images
In [109]: grid_img = torchvision.utils.make_grid(batch_tensor, nrow=5)
# check shape
In [110]: grid_img.shape
Out[110]: torch.Size([3, 518, 1292])
# reshape and plot (because MPL needs channel as the last dimension)
In [111]: plt.imshow(grid_img.permute(1, 2, 0))
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Out[111]: <matplotlib.image.AxesImage at 0x7f62081ef080>
将输出显示为:
【讨论】:
谢谢你,kmario23。我的错误是没有将网格视为要显示的图像,这意味着必须重新调整网格的形状:-) 嗨@kmario23,grid_img.permute(1, 2, 0)
在这里做什么?这里的 1、2、0 是什么?请解释一下好吗?
@Md.MusfiqurRahaman,如in [110] grid_img.shape
所示,grid_img的尺寸为[#颜色通道x图像高度x图像宽度]。相反,matplotlib.pyplot.imshow() 的输入需要为 [图像高度 x 图像宽度 x # 颜色通道](即,形状需要为 [518, 1292, 3]
)。 .permute(1, 2, 0)
动作是 Torch 特有的函数,它完全按照以下顺序排列原始轴:[轴 1 x 轴 2 x 轴 0] = [图像高度 x 图像宽度 x # 颜色通道]。【参考方案2】:
你必须先转换成 numpy
import numpy as np
def show(img):
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1,2,0)), interpolation='nearest')
w = torch.randn(10,3,640,640)
grid = torchvision.utils.make_grid(w, nrow=10, padding=100)
show(grid)
【讨论】:
嗨@iacolippo,(1, 2, 0) 在这里做什么?这里的 1、2、0 是什么?请解释一下好吗? 只是转换图像尺寸以将颜色通道放在最后 - 即从(color, width, height)
变为 (width, height, color)
以上是关于如何使用 plt.imshow 和 torchvision.utils.make_grid 在 PyTorch 中生成和显示图像网格?的主要内容,如果未能解决你的问题,请参考以下文章
plt.imshow() 和 plt.show() 没有图像弹出或显示