pytorch对象在保存图像时对于数组来说太深了
Posted
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了pytorch对象在保存图像时对于数组来说太深了相关的知识,希望对你有一定的参考价值。
我正在尝试从以下github rep运行代码:
https://github.com/iamkrut/image_inpainting_resnet_unet
我没有更改代码中的任何内容,当代码尝试保存图像时,它导致ValueError,对象太深。错误似乎来自这两行。
images = img_tensor.cpu().detach().permute(0,2,3,1)
plt.imsave(join(data_dir, 'samples', image), images[index,:,:,:3])
这里是错误说明
File "train.py", line 205, in <module>
data_dir=args.data_dir)
File "train.py", line 94, in train_net
plt.imsave(join(data_dir, 'samples', image), images[index,:,:,:]);
File "C:ProgramDataAnaconda3envs orch2libsite-packagesmatplotlibpyplot.py", line 2140, in imsave
return matplotlib.image.imsave(fname, arr, **kwargs)
File "C:ProgramDataAnaconda3envs orch2libsite-packagesmatplotlibimage.py", line 1498, in imsave
_png.write_png(rgba, fname, dpi=dpi)
ValueError: object too deep for desired array
任何人都知道可能是什么原因或如何解决?谢谢
答案
matplotlib软件包无法理解pytorch数据类型(张量)。您应该将张量数组转换为numpy数组,然后使用matplotlib函数。a = torch.rand(10, 3, 20, 20)
plt.imsave("test.jpg", a.cpu().detach().permute(0, 2, 3, 1)[0, ...]) # Error
plt.imsave("test.jpg", a.cpu().detach().permute(0, 2, 3, 1).numpy()[0, ...])
另一答案
我设法通过将行更改为来修复代码images=img_tensor.cpu().numpy()[0]
images = np.transpose(images, (1,2,0))
plt.imsave(join(data_dir, 'samples', image), images)
仍不确定先前版本有什么问题。因此,如果有人知道,请告诉我。
以上是关于pytorch对象在保存图像时对于数组来说太深了的主要内容,如果未能解决你的问题,请参考以下文章