Pytorch 错误:ValueError:图片应该是 2/3 维。有4个维度[关闭]

Posted

技术标签:

【中文标题】Pytorch 错误:ValueError:图片应该是 2/3 维。有4个维度[关闭]【英文标题】:Pytorch error: ValueError: pic should be 2/3 dimensional. Got 4 dimensions [closed] 【发布时间】:2021-01-29 12:07:07 【问题描述】:

尝试学习本教程here。虽然当我尝试使用 imshow() 函数时选择我的内容图像和样式图像时,我收到此错误:

ValueError: pic should be 2/3 dimensional. Got 4 dimensions.

使用谷歌我并没有真正找到解决这个问题的任何方法。

这是我的代码:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
import matplotlib.pyplot as plt
import torchvision.transforms as transforms 
import torchvision.models as models
import copy
import numpy as np

# This detects if cuda is available for GPU training otherwise will use CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Desired size of the output image
imsize = 512 if torch.cuda.is_available() else 256
print(imsize)

# Helper function
def image_loader(image_name, imsize):
    # Scale the imported image and transform it into a torch tensor
    loader = transforms.Compose([transforms.Resize(imsize), transforms.ToTensor()])
    image = Image.open(image_name)
    # Fake batch dimension required to fit network's input dimension
    image = loader(image).unsqueeze(0)
    return image.to(device, torch.float)

# Helper function to show the tensor as a PIL image
def imshow(tensor, title=None):
    unloader = transforms.ToPILImage()
    image = tensor.cpu().clone()
    image = unloader(image)
    plt.imshow(image)
    if title is not None:
        plt.title(title)
    plt.pause(0.001) # Pause so that the plots are updated

# Loading of images
image_directory = './images/'
style_img = image_loader(image_directory + "pb.jpg", imsize)
content_img = image_loader(image_directory + "content.jpg", imsize)
assert style_img.size() == content_img.size(), "we need to import style and content images of the same size"

plt.figure()
imshow(style_img, title='style image')

任何建议都会很有帮助。

这里是样式和内容图片供参考:

【问题讨论】:

【参考方案1】:

matplotlib.pyplot 需要 2D(灰度,dimensions=(W,H))或 3D(彩色,dimensions = (W,H,color channel))在imshow-函数中。

您可能仍然将批量大小作为张量中的第一维,因为在您的代码中您这样做:

# Fake batch dimension required to fit network's input dimension
image = loader(image).unsqueeze(0)

它添加了第一个维度。如果是这样,请尝试使用:

plt.imshow(np.squeeze(image))

plt.imshow(image[0])

【讨论】:

以上是关于Pytorch 错误:ValueError:图片应该是 2/3 维。有4个维度[关闭]的主要内容,如果未能解决你的问题,请参考以下文章

ValueError:pytorch 中的“str”python 维度太多

如何修复 Pytorch-Forecasting 模型拟合关于序列元素的 ValueError

Pytorch ValueError:优化器得到一个空的参数列表

Pytorch 闪电指标:ValueError:preds 和 target 必须具有相同数量的维度,或者 preds 的一个额外维度

PyTorch:将预训练模型从 3 个 RGB 通道更改为 4 个通道后,出现“ValueError:无法优化非叶张量”

ValueError:预期输入 batch_size (59) 与目标 batch_size (1) 匹配