Pytorch 闪电指标:ValueError:preds 和 target 必须具有相同数量的维度,或者 preds 的一个额外维度

Posted

技术标签:

【中文标题】Pytorch 闪电指标:ValueError:preds 和 target 必须具有相同数量的维度,或者 preds 的一个额外维度【英文标题】:Pytorch lightning metrics: ValueError: preds and target must have same number of dimensions, or one additional dimension for preds 【发布时间】:2021-06-03 01:13:33 【问题描述】:

谷歌搜索这会让你无处可去,所以我决定通过将其发布为可搜索的问题来帮助未来的我和其他人。


def __init__():
    ...
    self.val_acc = pl.metrics.Accuracy()

def validation_step(self, batch, batch_index):
    ...
    self.val_acc.update(log_probs, label_batch)

给予

ValueError: preds and target must have same number of dimensions, or one additional dimension for preds

对于log_probs.shape == (16, 4) 和对于label_batch.shape == (16, 4)

有什么问题?

【问题讨论】:

【参考方案1】:

pl.metrics.Accuracy() 需要一批 dtype=torch.long 标签,而不是一次性编码标签。

因此,应该喂它

self.val_acc.update(log_probs, torch.argmax(label_batch.squeeze(), dim=1))


这和torch.nn.CrossEntropyLoss一样

【讨论】:

以上是关于Pytorch 闪电指标:ValueError:preds 和 target 必须具有相同数量的维度,或者 preds 的一个额外维度的主要内容,如果未能解决你的问题,请参考以下文章

将 pytorch 闪电与香草 pytorch 混合

pytorch闪电模型的输出预测

用 pytorch 闪电组织张量板图

权重和偏差扫描无法使用 pytorch 闪电导入模块

使用 pytorch 闪电的不同测试结果

通过模型检查点时 Pytorch 闪电出错