MNIST、torchvision 中的输出和广播形状不匹配

Posted

技术标签:

【中文标题】MNIST、torchvision 中的输出和广播形状不匹配【英文标题】:Output and Broadcast shape mismatch in MNIST, torchvision 【发布时间】:2019-08-03 02:21:24 【问题描述】:

在 Torchvision 中使用 MNIST 数据集时出现以下错误

RuntimeError: output with shape [1, 28, 28] doesn't match the broadcast shape [3, 28, 28]

这是我的代码:

import torch
from torchvision import datasets, transforms

transform = transforms.Compose([transforms.ToTensor(),
                            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                          ])
trainset = datasets.MNIST('~/.pytorch/MNIST_data/', download=True, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
images, labels = next(iter(trainloader))

【问题讨论】:

MNIST 数据集只有 1 个通道。您需要更改transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(用于3个频道) 【参考方案1】:

错误是由于数据集上的颜色与灰度,数据集是灰度的。

我通过将转换更改为来修复它

transform = transforms.Compose([transforms.ToTensor(),
  transforms.Normalize((0.5,), (0.5,))
])

【讨论】:

感谢它的工作。你能简单地解释一下你的解决方案吗?为什么会这样? transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 3 个值表示 3 个通道,但是对于 mnist 只有 1 个通道,因此 transforms.Normalize(( 0.5,), (0.5,)) 谢谢吉斌,有道理

以上是关于MNIST、torchvision 中的输出和广播形状不匹配的主要内容,如果未能解决你的问题,请参考以下文章

[Pytorch系列-33]:数据集 - torchvision与MNIST数据集

如何将pytorch中的标签转换为onehot

【Pytorch+torchvision】MNIST手写数字识别(代码附最详细注释)

pytorch土堆pytorch教程学习torchvision 中的数据集的使用

PyTorch 和 TorchVision FasterRCNN 解释 C++ GenericDict 中的输出

PyTorch学习笔记 5.torchvision库