我想确认哪些计算 Dice Loss 的方法是正确的

Posted

技术标签:

【中文标题】我想确认哪些计算 Dice Loss 的方法是正确的【英文标题】:I want to confirm which of these methods to calculate Dice Loss is correct 【发布时间】:2021-07-17 16:22:43 【问题描述】:

所以我有 4 种计算骰子损失的方法,其中 3 种返回相同的结果,所以我可以得出结论,其中 1 种计算错误,但我想和你们确认一下:

import torch
torch.manual_seed(0)

inputs = torch.rand((3,1,224,224))
target = torch.rand((3,1,224,224))

方法一:展平张量

def method1(inputs, target):

    inputs = inputs.reshape( -1)

    target = target.reshape( -1)

    intersection = (inputs * target).sum()
    union = inputs.sum() + target.sum()
    dice = (2. * intersection) / (union + 1e-8)
    dice = dice.sum()

    print("method1", dice)

方法2:除batch size外的张量展平,对所有dim求和

def method2(inputs, target):
    num = target.shape[0]
    inputs = inputs.reshape(num, -1)

    target = target.reshape(num, -1)

    intersection = (inputs * target).sum()
    union = inputs.sum() + target.sum()
    dice = (2. * intersection) / (union + 1e-8)
    dice = dice.sum()/num

    print("method2", dice)

方法3:除batch size外的张量扁平化,sum dim 1

def method3(inputs, target):
    num = target.shape[0]
    inputs = inputs.reshape(num, -1)

    target = target.reshape(num, -1)

    intersection = (inputs * target).sum(1)
    union = inputs.sum(1) + target.sum(1)
    dice = (2. * intersection) / (union + 1e-8)
    dice = dice.sum()/num

    print("method3", dice)

方法 4:不要展平张量

def method4(inputs, target):

    intersection = (inputs * target).sum()
    union = inputs.sum() + target.sum()
    dice = (2. * intersection) / (union + 1e-8)


    print("method4", dice)

method1(inputs, target)
method2(inputs, target)
method3(inputs, target)
method4(inputs, target)

方法 1,3 和 4 打印:0.5006 方法2打印:0.1669

这是有道理的,因为我在 3 个维度上将输入和目标展平,忽略了批量大小,然后我将展平产生的所有 2 个维度相加,而不仅仅是暗淡 1

方法 4 似乎是最优化的一种

【问题讨论】:

【参考方案1】:

首先,您需要确定您报告的骰子得分:批次中所有样本的骰子得分(方法 1,2 和 4)或批次中每个样本的平均骰子得分(方法 3)。 如果我没记错的话,您想使用方法 3 - 您想优化批次中每个样本的骰子得分,而不是“全局”骰子得分:假设您在“容易”中有一个“困难”样本“ 批。 “困难”样本的错误分类像素相对于所有其他像素而言可以忽略不计。但是,如果您单独查看每个样本的骰子分数,那么“困难”样本的骰子分数将不可忽略。

【讨论】:

以上是关于我想确认哪些计算 Dice Loss 的方法是正确的的主要内容,如果未能解决你的问题,请参考以下文章

如何将模型测试预测转换为 png

图像分割loss集合

IoU 和 Dice

常用优化器算法归纳介绍

在 Keras 和输入参数数据类型中使用 Earth Mover Loss 方法

训练网络中出现loss等于Nan的情况几种思路