Pytorch 语义分割损失函数
Posted
技术标签:
【中文标题】Pytorch 语义分割损失函数【英文标题】:Pytorch semantic segmentation loss function 【发布时间】:2021-07-30 18:11:20 【问题描述】:我是细分模型的新手。
我想使用 deeplabv3_resnet50 模型。
我的图像形状为(256, 256, 3)
,我的标签形状为(256, 256)
。我标签中的每个像素都有一个类值(0-4)。 DataLoader 中设置的批大小为 32。
因此,我的输入批次的形状是[32, 3, 256, 256]
,对应目标的形状是[32, 256, 256]
。我相信这是正确的。
我尝试使用nn.BCEWithLogitsLoss()
。
-
对于我的情况,这是正确的损失函数吗?或者我应该使用
改为
CrossEntropy
?
如果这是正确的,我的模型的输出是[32, 5, 256, 256]
。每个图像预测的形状为[5,256, 256]
,第 0 层是否表示第 0 类的非标准化概率?为了使 [32, 256, 256]
张量与目标匹配以馈入 BCEWithLogitsLoss
,我是否需要将未标准化的概率转换为类?
如果我应该使用CrossEntropy
,我的输出和标签的大小应该是多少?
谢谢大家。
【问题讨论】:
【参考方案1】:你使用了错误的损失函数。
nn.BCEWithLogitsLoss()
代表 Binary 交叉熵损失:这是 Binary 标签的损失。在您的情况下,您有 5 个标签 (0..4)。
您应该使用nn.CrossEntropyLoss
:为离散标签设计的损失,超出二进制情况。
您的模型应该输出一个形状为 [32, 5, 256, 256]
的张量:对于该批次的 32 个图像中的每个像素,它应该输出一个 logits 的 5 维向量。 logits 是每个类的“原始”分数,稍后将使用 softmax 函数将其归一化为类概率。
为了数值稳定性和计算效率,nn.CrossEntropyLoss
不需要您显式计算 logits 的 softmax,而是在内部为您完成。如文档所述:
此标准将 LogSoftmax 和 NLLLoss 组合在一个类中。
【讨论】:
知道了。如果我想稍后计算 IOU 或像素精度,我是否应该将输出设为[32, 256, 256]
(可能是 output.argmax(dim=1))以匹配我的标签?
@KKKcat argmax 在通道暗淡上应该会给你预测的标签【参考方案2】:
鉴于您正在处理 5 个类,您应该使用 CrossEntropyLoss。顾名思义,二元交叉熵是您在拥有二元分割图时使用的损失函数。
PyTorch 中的 CrossEntropy 函数期望模型的输出具有以下形状 - [batch, num_classes, H, W]
(将其直接传递给您的损失函数),而基本事实的形状为 [batch, H, W]
其中H, W
在您的情况是 256、256。另外请通过在张量上调用 .long()
来确保基本事实是 long
类型
【讨论】:
以上是关于Pytorch 语义分割损失函数的主要内容,如果未能解决你的问题,请参考以下文章
为啥训练多类语义分割的unet模型中的分类交叉熵损失函数非常高?