pytorch对象在保存图像时对于数组来说太深了

Posted

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了pytorch对象在保存图像时对于数组来说太深了相关的知识,希望对你有一定的参考价值。

我正在尝试从以下github rep运行代码:

https://github.com/iamkrut/image_inpainting_resnet_unet

我没有更改代码中的任何内容,当代码尝试保存图像时,它导致ValueError,对象太深。错误似乎来自这两行。

images = img_tensor.cpu().detach().permute(0,2,3,1)
plt.imsave(join(data_dir, 'samples', image), images[index,:,:,:3])

这里是错误说明

  File "train.py", line 205, in <module>
    data_dir=args.data_dir)
  File "train.py", line 94, in train_net
    plt.imsave(join(data_dir, 'samples', image), images[index,:,:,:]);
  File "C:ProgramDataAnaconda3envs	orch2libsite-packagesmatplotlibpyplot.py", line 2140, in imsave
    return matplotlib.image.imsave(fname, arr, **kwargs)
  File "C:ProgramDataAnaconda3envs	orch2libsite-packagesmatplotlibimage.py", line 1498, in imsave
    _png.write_png(rgba, fname, dpi=dpi)
ValueError: object too deep for desired array

任何人都知道可能是什么原因或如何解决?谢谢

答案
matplotlib软件包无法理解pytorch数据类型(张量)。您应该将张量数组转换为numpy数组,然后使用matplotlib函数。

a = torch.rand(10, 3, 20, 20) plt.imsave("test.jpg", a.cpu().detach().permute(0, 2, 3, 1)[0, ...]) # Error plt.imsave("test.jpg", a.cpu().detach().permute(0, 2, 3, 1).numpy()[0, ...])

另一答案
我设法通过将行更改为来修复代码

images=img_tensor.cpu().numpy()[0] images = np.transpose(images, (1,2,0)) plt.imsave(join(data_dir, 'samples', image), images)

仍不确定先前版本有什么问题。因此,如果有人知道,请告诉我。

以上是关于pytorch对象在保存图像时对于数组来说太深了的主要内容,如果未能解决你的问题,请参考以下文章

使用PyTorch进行数据处理

[PyTorch入门之60分钟入门闪击战]之训练分类器

pytorch学习笔记第五篇——训练分类器

ValueError:对象太深,无法在 optimize.curve_fit 中找到所需数组

裁剪对于 div 来说太大的图像(或其他对象)

CV基础基于Pytorch-Unet训练多类别分割并测试