正确使用交叉熵作为元素序列的损失函数
Posted
技术标签:
【中文标题】正确使用交叉熵作为元素序列的损失函数【英文标题】:Correct use of Cross-entropy as a loss function for sequence of elements 【发布时间】:2021-11-20 20:18:50 【问题描述】:我有一个序列标记任务。
因此,作为输入,我有一个形状为 [batch_size, sequence_length]
的元素序列,并且该序列的每个元素都应分配给某个类。
作为训练神经网络期间的损失函数,我使用Cross-entropy。
我应该如何正确使用它?
我的变量target_predictions
的形状为[batch_size, sequence_length, number_of_classes]
,target
的形状为[batch_size, sequence_length]
。
文档说:
我知道如果我使用CrossEntropyLoss(target_predictions.permute(0, 2, 1), target)
,一切都会正常工作。但我担心 torch 会将我的 sequence_length
解释为屏幕截图中的 d_1
变量,并且会认为这是一个多维损失,但事实并非如此。
我应该如何正确做?
【问题讨论】:
【参考方案1】:使用 CE 损失将给您损失而不是标签。默认情况下,将采用您可能追求的平均值,并且带有 permute 的 sn-p 会很好(使用此损失,您可以通过向后训练您的 nn)。
要获得预测的类,只需在适当的维度上取 argmax,在没有排列的情况下,它将是:
labels = torch.argmax(target_predictions, dim=-1)
这将为您提供包含类的 (batch, sequence_length) 输出。
【讨论】:
是的,我使用 CE 作为损失。您认为轴的排列就足够了,pytorch不会因为没有多维损失而与d1
变量混淆?
它不会,只要你坚持适当的维度,它就适用于多维情况(例如分割)。以上是关于正确使用交叉熵作为元素序列的损失函数的主要内容,如果未能解决你的问题,请参考以下文章