How do I display a single image in PyTorch?
【Title】: How do I display a single image in PyTorch? 【Posted】: 2019-05-06 12:01:55 【Question】: I want to display a single image that was loaded with ImageLoader and is stored in a PyTorch Tensor. When I try to display it with plt.imshow(image), I get:
TypeError: Invalid dimensions for image data
The tensor's .shape is:
torch.Size([3, 244, 244])
How can I display a PyTorch tensor as an image?
【Comments】:
【Answer 1】: I wrote a simple function that visualizes PyTorch tensors with matplotlib.
import numpy as np
import matplotlib.pyplot as plt
import torch

def show(*imgs):
    '''
    Input imgs can be a single tensor or multiple tensors; this function uses matplotlib to visualize them.
    Single input example:
    show(x) visualizes x, where x should be a torch.Tensor.
        If x is a 4D tensor (e.g. an image batch of size b(atch)*c(hannel)*h(eight)*w(idth)), this function splits x along the batch dimension and shows b subplots in total, where each subplot displays at most the first 3 channels (3*h*w).
        If x is a 3D tensor, it is shown as grayscale (1 channel) or as RGB (3 channels).
        If x is a 2D tensor, it is shown as a grayscale map.
    Multiple input example:
    show(x, y, z) produces three windows, displaying x, y and z respectively, where x, y, z can be in any form described above.
    '''
    img_idx = 0
    for img in imgs:
        img_idx += 1
        plt.figure(img_idx)
        if isinstance(img, torch.Tensor):
            # Drop autograd history and move to the CPU before converting to numpy.
            img = img.detach().cpu()
            if img.dim() == 4:  # 4D tensor
                bz = img.shape[0]
                c = img.shape[1]
                if bz == 1 and c == 1:  # single grayscale image
                    img = img.squeeze()
                elif bz == 1 and c == 3:  # single RGB image
                    img = img.squeeze()
                    img = img.permute(1, 2, 0)
                elif bz == 1 and c > 3:  # multiple feature maps
                    img = img[:, 0:3, :, :]
                    img = img.permute(0, 2, 3, 1)[:]
                    print('warning: more than 3 channels! only channels 0,1,2 are preserved!')
                elif bz > 1 and c == 1:  # multiple grayscale images
                    img = img.squeeze()
                elif bz > 1 and c == 3:  # multiple RGB images
                    img = img.permute(0, 2, 3, 1)
                elif bz > 1 and c > 3:  # multiple feature maps
                    img = img[:, 0:3, :, :]
                    img = img.permute(0, 2, 3, 1)[:]
                    print('warning: more than 3 channels! only channels 0,1,2 are preserved!')
                else:
                    raise Exception("unsupported type! " + str(img.size()))
            elif img.dim() == 3:  # 3D tensor
                bz = 1
                c = img.shape[0]
                if c == 1:  # grayscale
                    img = img.squeeze()
                elif c == 3:  # RGB
                    img = img.permute(1, 2, 0)
                else:
                    raise Exception("unsupported type! " + str(img.size()))
            elif img.dim() == 2:
                bz = 1  # a 2D tensor is a single grayscale map; bz must be set for the branch below
            else:
                raise Exception("unsupported type! " + str(img.size()))
            img = img.numpy()  # convert to numpy
            img = img.squeeze()
            if bz == 1:
                plt.imshow(img, cmap='gray')
                # plt.colorbar()
                # plt.show()
            else:
                # Lay the batch out on a roughly square grid of subplots.
                for idx in range(0, bz):
                    plt.subplot(int(bz**0.5), int(np.ceil(bz / int(bz**0.5))), int(idx + 1))
                    plt.imshow(img[idx], cmap='gray')
        else:
            raise Exception("unsupported type: " + str(type(img)))
【Comments】:
【Answer 2】: Use fastai's show_image:
from fastai.vision.all import show_image
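【Comments】:
【Answer 3】: PyTorch modules that process image data expect tensors in C × H × W format [1], whereas Pillow and Matplotlib expect image arrays in H × W × C format [2].
You can easily convert tensors to and from this format with a TorchVision transform:
import torchvision.transforms.functional as F
F.to_pil_image(image_tensor)
Or by permuting the axes directly:
image_tensor.permute(1, 2, 0)
[1] "PyTorch modules dealing with image data require tensors to be laid out as C × H × W: channels, height, and width, respectively."
[2] "Note how we have to use permute to change the order of the axes from C × H × W to H × W × C to match what Matplotlib expects."
(Quoted from Deep Learning with PyTorch.)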
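A minimal sketch of the axis reordering described above. The name image_tensor and the 244 × 244 size follow the question; the random pixel data is an assumption that stands in for a real loaded image.

```python
import torch

# Stand-in for an image produced by a PyTorch data pipeline: C x H x W.
image_tensor = torch.rand(3, 244, 244)

# permute reorders the axes without changing pixel values: C x H x W -> H x W x C.
hwc = image_tensor.permute(1, 2, 0)

print(tuple(image_tensor.shape))  # (3, 244, 244)
print(tuple(hwc.shape))           # (244, 244, 3)

# The pixel at row 10, column 20 still carries the same three channel values.
same_pixel = torch.equal(hwc[10, 20], image_tensor[:, 10, 20])
print(same_pixel)  # True
```

The resulting hwc tensor is in the layout that plt.imshow accepts.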
【Comments】:
【Answer 4】: Assuming the image is loaded as described and stored in the variable image:
plt.imshow(transforms.ToPILImage()(image), interpolation="bicubic")
#transforms.ToPILImage()(image).show() # Alternatively
Or, as Soumith suggested:
def show(img):
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)), interpolation='nearest')
【Comments】:
Maybe add the import to the code: import torchvision.transforms
【Answer 5】: A complete example, given an image path img_path:
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms
image = Image.open(img_path)
plt.imshow(transforms.ToPILImage()(transforms.ToTensor()(image)), interpolation="bicubic")
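Note that transforms.ToTensor and transforms.ToPILImage are classes: you first create an instance and then call that instance on the image, which is why the funky double parentheses appear.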
【Comments】:
【Answer 6】: Given a Tensor representing the image, use .permute() to make the channels the last dimension:
plt.imshow( tensor_image.permute(1, 2, 0) )
Note: permute does not copy or allocate memory, and from_numpy() doesn't either.
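A small sketch that checks the no-copy claim, assuming a tiny zero-filled array in place of a real image. A write to the NumPy array is visible through the permuted tensor because all three objects share one buffer.

```python
import numpy as np
import torch

# Tiny H x W x C array standing in for a decoded image.
np_image = np.zeros((4, 5, 3), dtype=np.uint8)

hwc = torch.from_numpy(np_image)  # shares memory with np_image
chw = hwc.permute(2, 0, 1)        # a view with reordered strides, no copy

# Modify the NumPy array after the tensors were created.
np_image[0, 0, 0] = 255

shared_value = chw[0, 0, 0].item()
same_buffer = hwc.data_ptr() == chw.data_ptr()

print(shared_value)          # 255
print(same_buffer)           # True
print(chw.is_contiguous())   # False: only the strides changed
```

If an independent copy is needed, calling .clone() or .contiguous() on the permuted tensor allocates new memory.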
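【Comments】:
Wow, thanks... this worked for me. I was trying tensor_image.numpy().reshape([224,224,3]) and visualizing the result with cv2.imshow(), but I did not get the actual image. What went wrong here?
@DevashishPrasad The problem is that reshape([224,224,3]) does not do the same thing as permute(1, 2, 0). permute is like transposing a matrix, where rows become columns and columns become rows. reshape does something completely unrelated that I don't know how to describe concisely. In short, reshape is the wrong function here.
【Answer 7】: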
matplotlib works fine with a tensor even without converting it to a numpy array. But PyTorch image tensors are channel-first, so to use them with matplotlib you need to move the channel axis to the end. Do this with permute: view only regroups the elements in memory order and would scramble the pixels.
Code:
from scipy.datasets import face  # scipy.misc.face before SciPy 1.10; needs the pooch package
import matplotlib.pyplot as plt
import torch
np_image = face()
print(type(np_image), np_image.shape)
tensor_image = torch.from_numpy(np_image)
print(type(tensor_image), tensor_image.shape)
# move channels first: (H, W, C) -> (C, H, W)
tensor_image = tensor_image.permute(2, 0, 1)
print(type(tensor_image), tensor_image.shape)
# If you try to plot an image with shape (C, H, W)
# you will get a TypeError:
# plt.imshow(tensor_image)
# So we need to move the channels back to the end: (C, H, W) -> (H, W, C)
tensor_image = tensor_image.permute(1, 2, 0)
print(type(tensor_image), tensor_image.shape)
plt.imshow(tensor_image)
plt.show()
Output:
<class 'numpy.ndarray'> (768, 1024, 3)
<class 'torch.Tensor'> torch.Size([768, 1024, 3])
<class 'torch.Tensor'> torch.Size([3, 768, 1024])
<class 'torch.Tensor'> torch.Size([768, 1024, 3])
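The printed shapes alone do not reveal whether the pixels ended up in the right place, because view or reshape and permute can produce identical shapes. A small sketch on a hypothetical 2 × 2 image with values 0 to 11 makes the difference visible:

```python
import torch

# A tiny H x W x C image: 2 x 2 pixels, 3 channels, values 0..11.
hwc = torch.arange(12).reshape(2, 2, 3)

chw_permute = hwc.permute(2, 0, 1)  # moves the channel axis to the front
chw_reshape = hwc.reshape(3, 2, 2)  # only regroups elements in memory order

# Both results have the same shape.
print(tuple(chw_permute.shape), tuple(chw_reshape.shape))  # (3, 2, 2) (3, 2, 2)

# Channel 0 should hold the first value of every pixel: 0, 3, 6, 9.
print(chw_permute[0].tolist())  # [[0, 3], [6, 9]]
print(chw_reshape[0].tolist())  # [[0, 1], [2, 3]]
```

Only the permuted tensor keeps each pixel's channel values together, which is why a reshaped image looks scrambled when displayed.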
【Comments】:
Hmm, that doesn't work for me; see the updated question for the tensor shape.