如何在pytorch多类问题的交叉熵损失中设置目标

Posted

技术标签:

【中文标题】如何在pytorch多类问题的交叉熵损失中设置目标【英文标题】:How to set target in cross entropy loss for pytorch multi-class problem 【发布时间】:2020-09-06 08:10:42 【问题描述】:

问题说明:我有一张图片,图片的一个像素只能属于Band5','Band6', 'Band7' 之一(详情见下文)。因此,我有一个 pytorch 多类问题,但我无法理解如何设置需要采用 [batch, w, h] 形式的目标

我的数据加载器返回两个值:

x = chips.loc[:, :, :, self.input_bands]     
y = chips.loc[:, :, :, self.output_bands]        
x = x.transpose('chip','channel','x','y')
y_ohe = y.transpose('chip','channel','x','y')

另外,我已经定义了:

input_bands = ['Band1','Band2', 'Band3', 'Band3', 'Band4']  # input classes
output_bands = ['Band5','Band6', 'Band7'] #target classes

model = ModelName(num_classes = 3, depth=default_depth, in_channels=5, merge_mode='concat').to(device)
loss_new = nn.CrossEntropyLoss()

在我的训练函数中:

        #get values from dataloader
        X = normalize_zero_to_one(X) #input
        y = normalize_zero_to_one(y) #target

        images = Variable(torch.from_numpy(X)).to(device) # [batch, channel, H, W]
        masks = Variable(torch.from_numpy(y)).to(device) 
        optim.zero_grad()        
        outputs = model(images) 

        loss = loss_new(outputs, masks) # (preds, target)
        loss.backward()         
        optim.step() # Update weights  

我知道目标(这里是masks)应该是[batch_size, w, h]。不过,目前是[batch_size, channels, w, h]

我读了很多帖子,包括1、2,他们说的是the target should only contain the target class indices。我不明白如何连接三个类的索引并仍然将目标设置为[batch_size, w, h]

现在,我得到了错误:

RuntimeError: only batches of spatial targets supported (3D tensors) but got targets of dimension: 4

据我所知,我不需要做任何热编码。我在网上找到的类似错误和解释在这里:'

Reference 1 Reference 2 Reference 3 Reference 4

任何帮助将不胜感激!谢谢。

【问题讨论】:

【参考方案1】:

如果我理解正确,您当前的“目标”是 [batch_size, channels, w, h]channels==3,因为您有三个可能的目标。 您的目标中的代表什么?您基本上每个像素都有一个 3 向量目标 - 这些是预期的类概率吗?它们是表示正确“波段”的“单热向量”吗? 如果是这样,您只需将argmax 沿目标通道维度获取即可获得目标索引:

proper_target = torch.argmax(masks, dim=1)  # make sure keepdim=False
loss = loss_new(outputs, proper_target)

【讨论】:

以上是关于如何在pytorch多类问题的交叉熵损失中设置目标的主要内容,如果未能解决你的问题,请参考以下文章

如何在 PyTorch 中计算自举交叉熵损失?

当目标不是单热时,如何计算 Pytorch 中 2 个张量之间的正确交叉熵?

详解pytorch中的交叉熵损失函数nn.BCELoss()nn.BCELossWithLogits(),二分类任务如何定义损失函数,如何计算准确率如何预测

为啥训练多类语义分割的unet模型中的分类交叉熵损失函数非常高?

交叉熵损失 Pytorch

为啥我不能将交叉熵损失用于多标签?