Pytorch 闪电指标:ValueError:preds 和 target 必须具有相同数量的维度,或者 preds 的一个额外维度
Posted
技术标签:
【中文标题】Pytorch 闪电指标:ValueError:preds 和 target 必须具有相同数量的维度,或者 preds 的一个额外维度【英文标题】:Pytorch lightning metrics: ValueError: preds and target must have same number of dimensions, or one additional dimension for preds 【发布时间】:2021-06-03 01:13:33 【问题描述】:谷歌搜索这会让你无处可去,所以我决定通过将其发布为可搜索的问题来帮助未来的我和其他人。
def __init__():
...
self.val_acc = pl.metrics.Accuracy()
def validation_step(self, batch, batch_index):
...
self.val_acc.update(log_probs, label_batch)
给予
ValueError: preds and target must have same number of dimensions, or one additional dimension for preds
对于log_probs.shape == (16, 4)
和对于label_batch.shape == (16, 4)
有什么问题?
【问题讨论】:
【参考方案1】:pl.metrics.Accuracy()
需要一批 dtype=torch.long
标签,而不是一次性编码标签。
因此,应该喂它
self.val_acc.update(log_probs, torch.argmax(label_batch.squeeze(), dim=1))
这和torch.nn.CrossEntropyLoss
一样
【讨论】:
以上是关于Pytorch 闪电指标:ValueError:preds 和 target 必须具有相同数量的维度,或者 preds 的一个额外维度的主要内容,如果未能解决你的问题,请参考以下文章