MNIST、torchvision 中的输出和广播形状不匹配
Posted
技术标签:
【中文标题】MNIST、torchvision 中的输出和广播形状不匹配【英文标题】:Output and Broadcast shape mismatch in MNIST, torchvision 【发布时间】:2019-08-03 02:21:24 【问题描述】:在 Torchvision 中使用 MNIST 数据集时出现以下错误
RuntimeError: output with shape [1, 28, 28] doesn't match the broadcast shape [3, 28, 28]
这是我的代码:
import torch
from torchvision import datasets, transforms
transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
trainset = datasets.MNIST('~/.pytorch/MNIST_data/', download=True, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
images, labels = next(iter(trainloader))
【问题讨论】:
MNIST
数据集只有 1 个通道。您需要更改transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
(用于3个频道)
【参考方案1】:
错误是由于数据集上的颜色与灰度,数据集是灰度的。
我通过将转换更改为来修复它
transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
【讨论】:
感谢它的工作。你能简单地解释一下你的解决方案吗?为什么会这样? transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 3 个值表示 3 个通道,但是对于 mnist 只有 1 个通道,因此 transforms.Normalize(( 0.5,), (0.5,)) 谢谢吉斌,有道理以上是关于MNIST、torchvision 中的输出和广播形状不匹配的主要内容,如果未能解决你的问题,请参考以下文章
[Pytorch系列-33]:数据集 - torchvision与MNIST数据集
【Pytorch+torchvision】MNIST手写数字识别(代码附最详细注释)
pytorch土堆pytorch教程学习torchvision 中的数据集的使用