pytorch Crossentropy 导致无法匹配的批量大小
Posted
技术标签:
【中文标题】pytorch Crossentropy 导致无法匹配的批量大小【英文标题】:pytorch Crossentropy results in unmatched batch size 【发布时间】:2020-04-07 12:42:56 【问题描述】:我正在使用数据加载器加载图像文件夹
图像文件夹由三个类别(标签)组成,即
'/root/ant/dsd.png'
'/root/ant/sfds.png'
...
....
'/root/bee/dsf.png'
....
..
'/root/whey/sfd.png'
这里有蚂蚁、蜜蜂、乳清三类
通过执行上面的代码,我得到了一个错误,输出和目标的 bacth 大小不匹配
错误:预期输入 batch_size (3) 与目标 batch_size (1) 匹配。
我认为错误可能出在 trainloader 中,因为提取了不同形状不匹配的标签
data_transform = transforms.Compose([
transforms.Resize(size=28),
transforms.ToTensor()
])
kumda_dataset = datasets.ImageFolder(root='/content/gdrive/My Drive/Colab Notebooks/images',
transform=data_transform)
#train & test
train_size = int(0.8 * len(kumda_dataset))
test_size = len(kumda_dataset) - train_size
#splitting
train_dataset, test_dataset = torch.utils.data.random_split(kumda_dataset, [train_size, test_size])
trainloader = torch.utils.data.DataLoader(train_dataset , batch_size = 1, shuffle = True)
testloader = torch.utils.data.DataLoader(test_dataset , batch_size = 4, shuffle = False )
model=nn.Linear(784,1)
criterion=nn.CrossEntropyLoss()
optimizer=optim.SGD(model.parameters(),lr=0.01)
num=10
for epoch in range(num):
for i,(images,labels) in enumerate(trainloader):
images=images.reshape(-1,28*28)
output=model(images)
loss=criterion(output,labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if(i+1%70==0):
print("Epoch: /, \tIteration: /, \tLoss: ".format(epoch + 1, num, i + 1,len(dataset_loader), loss.item()))
提前感谢您的解决
【问题讨论】:
【参考方案1】:您的线性层只为每个批次项目输出一个值。 CrossEntropyLoss
期望每个类有一个输出维度。改为
model=nn.Linear(784, 3)
因为你有 3 个班级。
【讨论】:
以上是关于pytorch Crossentropy 导致无法匹配的批量大小的主要内容,如果未能解决你的问题,请参考以下文章
python 神经网络损失 = 'categorical_crossentropy' vs 'binary_crossentropy' isse
Keras:binary_crossentropy 和 categorical_crossentropy 混淆
Binary_crossentropy 和 Categorical_crossentropy 之间的混淆
为啥对于 Keras 中的多类分类, binary_crossentropy 比 categorical_crossentropy 更准确?
python Keras sparse_categorical_crossentropy vs categorical_crossentropy