ValueError:预期输入 batch_size (59) 与目标 batch_size (1) 匹配
Posted
技术标签:
【中文标题】ValueError:预期输入 batch_size (59) 与目标 batch_size (1) 匹配【英文标题】:ValueError: Expected input batch_size (59) to match target batch_size (1) 【发布时间】:2021-09-04 06:43:29 【问题描述】:我正在尝试使用 pytorch 构建语义分割模型。但是,我遇到了这个错误,不知道如何解决。
这是模型:
class SegmentationNN(pl.LightningModule):
def __init__(self, num_classes=23, hparams=None):
super().__init__()
self.hparams = hparams
self.model=models.alexnet(pretrained=True).features
self.conv=nn.Conv2d(256, 3, kernel_size=1)
self.upsample = nn.Upsample(size=(240,240))
def forward(self, x):
print('Input:', x.shape)
x = self.model(x)
print('After Alexnet convs:', x.shape)
x = self.conv(x)
print('After 1-conv:', x.shape)
x = self.upsample(x)
print('After upsampling:', x.shape)
return x
def training_step(self, batch, batch_idx):
images, targets = batch
# targets = targets.view(targets.size(0), -1)
out = self.forward(images)
loss_func = nn.CrossEntropyLoss(ignore_index=-1, reduction='mean')
loss = loss_func(out, targets.unsqueeze(0))
tensorboard_logs = 'loss': loss
return 'loss': loss, 'log':tensorboard_logs
def validation_step(self, batch, batch_idx):
images, targets = batch
# targets = targets.view(targets.size(0), -1)
out = self.forward(images)
loss_func = nn.CrossEntropyLoss(ignore_index=-1, reduction='mean')
loss = loss_func(out, targets.unsqueeze(0))
tensorboard_logs = 'loss': loss
return 'loss': loss, 'log':tensorboard_logs
def configure_optimizers(self):
optim = torch.optim.Adam(self.parameters(), lr=self.hparams['learning_rate'])
return optim
这就是训练和合身:
train_dataloader = DataLoader(train_data, batch_size=hparams['batch_size'])
val_dataloader = DataLoader(val_data, batch_size=hparams['batch_size'])
trainer = pl.Trainer(
max_epochs=50,
gpus=1 if torch.cuda.is_available() else None
)
pass
trainer.fit(model, train_dataloader, val_dataloader)
这些是每一层之后张量的大小:
Input: torch.Size([59, 3, 240, 240])
After Alexnet convs: torch.Size([59, 256, 6, 6])
After 1-conv: torch.Size([59, 3, 6, 6])
After upsampling: torch.Size([59, 3, 240, 240])
我是 Pytorch 和 Pytorch Lightning 的初学者,所以每一个建议都会得到赞赏!
【问题讨论】:
【参考方案1】:你能在这里删除 unsqueeze(0) 部分吗:loss = loss_func(out, targets.unsqueeze(0))
【讨论】:
以上是关于ValueError:预期输入 batch_size (59) 与目标 batch_size (1) 匹配的主要内容,如果未能解决你的问题,请参考以下文章
预期的输入batch_size以匹配目标batch_size(11)
ValueError: 标签形状必须为 [batch_size, labels_dimension],得到 (128, 2)
ValueError:输入 0 与层 lstm_1 不兼容:预期 ndim=3,发现 ndim=2 [keras]