在 unet 架构中使用自定义权重图的正确方法

Posted

技术标签:

【中文标题】在 unet 架构中使用自定义权重图的正确方法【英文标题】:Correct way to use custom weight maps in unet architecture 【发布时间】:2020-02-11 03:27:12 【问题描述】:

u-net 架构中有一个著名的技巧是使用自定义权重图来提高准确性。以下是它的详细信息:

现在,通过在这里和其他多个地方询问,我了解了两种方法。我想知道哪个是正确的,或者还有其他更正确的方法吗?

    首先是在训练循环中使用torch.nn.Functional方法:

    loss = torch.nn.functional.cross_entropy(output, target, w) 其中 w 将是计算出的自定义重量。

    二是在训练循环外调用损失函数时使用reduction='none' criterion = torch.nn.CrossEntropy(reduction='none')

    然后在训练循环中乘以自定义权重:

    gt # Ground truth, format torch.long
    pd # Network output
    W # per-element weighting based on the distance map from UNet
    loss = criterion(pd, gt)
    loss = W*loss # Ensure that weights are scaled appropriately
    loss = torch.sum(loss.flatten(start_dim=1), axis=0) # Sums the loss per image
    loss = torch.mean(loss) # Average across a batch
    

现在,我有点困惑哪一个是对的,还是有其他方法,或者两者都是对的?

【问题讨论】:

【参考方案1】:

加权部分看起来只是简单的加权交叉熵,它对类的数量(在下面的示例中为 2)执行这样的操作。

weights = torch.FloatTensor([.3, .7])
loss_func = nn.CrossEntropyLoss(weight=weights)

编辑:

你见过帕特里克·布莱克的this implementation吗?

# Set properties
batch_size = 10
out_channels = 2
W = 10
H = 10

# Initialize logits etc. with random
logits = torch.FloatTensor(batch_size, out_channels, H, W).normal_()
target = torch.LongTensor(batch_size, H, W).random_(0, out_channels)
weights = torch.FloatTensor(batch_size, 1, H, W).random_(1, 3)

# Calculate log probabilities
logp = F.log_softmax(logits)

# Gather log probabilities with respect to target
logp = logp.gather(1, target.view(batch_size, 1, H, W))

# Multiply with weights
weighted_logp = (logp * weights).view(batch_size, -1)

# Rescale so that loss is in approx. same interval
weighted_loss = weighted_logp.sum(1) / weights.view(batch_size, -1).sum(1)

# Average over mini-batch
weighted_loss = -1. * weighted_loss.mean()

【讨论】:

这里的重量是由某个函数计算出来的,并不谨慎。更多信息,这里有一篇论文-arxiv.org/abs/1505.04597 @Mark 哦,我现在明白了。所以这是一个像素级的损失输出。并且使用诸如opencv之类的库预先计算边界,然后为每个图像保存这些像素位置,然后在训练期间乘以损失张量,以便算法专注于减少这些区域的损失。 谢谢。这个合法的看起来像一个答案,我会尝试更多地验证和实施它,然后会接受你的答案。 你能解释一下这条线背后的直觉logp = logp.gather(1, target.view(batch_size, 1, H, W))【参考方案2】:

请注意,torch.nn.CrossEntropyLoss() 是一个调用 torch.nn.functional 的类。 见https://pytorch.org/docs/stable/_modules/torch/nn/modules/loss.html#CrossEntropyLoss

您可以在定义条件时使用权重。在功能上比较它们,这两种方法是相同的。

现在,我不明白您在方法 1 的训练循环内和方法 2 的训练循环外计算损失的想法。如果您在循环外计算损失,那么您将如何反向传播?

【讨论】:

我对使用torch.nn.CrossEntropyLoss() torch.nn.functional.cross_entropy(output, target, w) 并不感到困惑,我对如何在损失中使用自定义权重图感到困惑。请参阅这篇论文-arxiv.org/abs/1505.04597 并告诉我,如果你仍然无法弄清楚我在问什么 如果我理解正确的话,我认为方法2是正确的。损失 torch.nn.functional.cross_entropy(output, target, w) 中的权重 (w) 是公式中非 w(x) 类的权重。我们可以用一个小脚本轻松测试它。 是的,即使我得出了相同的结论。如果我的网络按预期运行,我会回复你,并将答案标记为已接受。 好吧,它不工作。当我运行 loss = loss*w 方法时,我得到grad can be implicitly created only for scalar outputs 你确定你是在总结还是取平均值?

以上是关于在 unet 架构中使用自定义权重图的正确方法的主要内容,如果未能解决你的问题,请参考以下文章

keras中的加权mse自定义损失函数 - 自定义权重

LayUi创建一个自定义通用模块

使用 CSS 修复自定义字体行高

百度地图聚合功能自定义聚合文字

office2010的PPT怎么自定义播放

如何在pytorch中获取自定义损失函数的权重?