使用 PyTorch 的交叉熵损失函数是不是需要 One-Hot Encoding?

Posted

技术标签:

【中文标题】使用 PyTorch 的交叉熵损失函数是不是需要 One-Hot Encoding?【英文标题】:Is One-Hot Encoding required for using PyTorch's Cross Entropy Loss Function?使用 PyTorch 的交叉熵损失函数是否需要 One-Hot Encoding? 【发布时间】:2020-10-08 21:34:14 【问题描述】:

例如,如果我想解决 MNIST 分类问题,我们有 10 个输出类。对于 PyTorch,我想使用 torch.nn.CrossEntropyLoss 函数。我是否必须对目标进行格式化以便对其进行一次热编码,或者我可以简单地使用数据集附带的类标签吗?

【问题讨论】:

【参考方案1】:

nn.CrossEntropyLoss 需要整数标签。它在内部所做的是,它根本不会对类标签进行一次性编码,而是使用标签来索引输出概率向量来计算损失,如果你决定使用这个类作为最终标签。这个小而重要的细节使计算损失更容易,并且是执行单热编码的等效操作,测量每个输出神经元的输出损失,因为输出层中的每个值都为零,除了目标类索引的神经元.因此,如果您已经提供了标签,则无需一次性对数据进行编码。

文档对此有更深入的了解:https://pytorch.org/docs/master/generated/torch.nn.CrossEntropyLoss.html。在文档中,您将看到 targets,它作为输入参数的一部分。这些是您的标签,它们被描述为:

这清楚地显示了输入应该如何塑造以及预期的内容。如果您实际上想对数据进行一次热编码,则需要使用torch.nn.functional.one_hot。为了最好地复制交叉熵损失在幕后所做的事情,您还需要 nn.functional.log_softmax 作为最终输出,并且您必须另外编写自己的损失层,因为 PyTorch 层都没有使用 log softmax 输入和一个-热编码目标。但是,nn.CrossEntropyLoss 将这两种操作结合​​在一起,如果您的输出只是简单的类标签,因此无需进行转换,这是首选。

【讨论】:

所以你的意思是如果我使用 CE 损失。目标输入是标签值。无需使用 one-hot 编码。对吗? @LoaySharaky 是的。为了阐明这一点,假设在您的批次中,您的输入张量为N x D,其中N 是批次大小,D 是单个示例的维度。目标应该只是大小为N 的一维张量,其中值可以从0C - 1C 是类的总数。但是,您的预测值的输出层的形状应该是N x C。因此,损失函数会将目标中的标签作为索引,直接访问输出层张量中的值来计算损失。 不确定这是否发生了变化,但nn.NLLLoss 不接受本评论中声称的单热编码向量作为目标。 @Aydo 是的,现在已经改变了。感谢您的评论。【参考方案2】:

如果您正在加载 ImageLoader 以从文件夹本身加载数据集,那么 PyTorch 将自动为您标记它们。您所要做的就是像这样构建文件夹:

|
|__train
|   |
|   |__1
|   |_ 2
|   |_ 3
|   .
|   .
|   .
|   |_10
|
|__test
    |
    |__1
    |_ 2
    |_ 3
    .
    .
    .
    |_10

每个班级都应该有一个单独的文件夹。如果您从 DataFrame 加载数据,您可以使用以下代码对其进行编码:

one_hot = torch.nn.functional.one_hot(target)

【讨论】:

我不认为 OP 在询问如何加载数据或如何将标签转换为一种热编码形式。他们特别询问 CE 损失是否需要热编码的标签。 @akshayk07 完全正确 OP 想知道是否可以将标签提供给 PyTorch 中的交叉熵损失函数,而无需一次性编码。大概他们已经准备好了标签,想知道这些标签是否可以直接插入到函数中。 OP 不想知道如何进行一次热编码,所以这并不能真正回答问题。

以上是关于使用 PyTorch 的交叉熵损失函数是不是需要 One-Hot Encoding?的主要内容,如果未能解决你的问题,请参考以下文章

详解pytorch中的交叉熵损失函数nn.BCELoss()nn.BCELossWithLogits(),二分类任务如何定义损失函数,如何计算准确率如何预测

Pytorch常用的交叉熵损失函数CrossEntropyLoss()详解

Pytorch常用的交叉熵损失函数CrossEntropyLoss()详解

Pytorch - 使用一种热编码和 softmax 的(分类)交叉熵损失

PyTorch 交叉熵损失函数内部原理简单实现

pytorch二元交叉熵损失函数 nn.BCELoss() 与 torch.nn.BCEWithLogitsLoss()