深入浅出PyTorch中的nn.CrossEntropyLoss

Posted aelum

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了深入浅出PyTorch中的nn.CrossEntropyLoss相关的知识,希望对你有一定的参考价值。

目录

一、前言

nn.CrossEntropyLoss 常用作多分类问题的损失函数(对交叉熵还不了解的读者可以看我的这篇文章),本文将围绕PyTorch的官方文档对重要知识点进行逐一讲解(不会全部讲解)。

import torch
import torch.nn as nn

二、理论基础

对于 C   ( C > 2 ) C\\,(C>2) C(C>2) 分类问题,先不考虑 batch 的情形,设神经网络的输出(还未经过 Softmax)为 x c c = 1 C \\x_c\\_c=1^C xcc=1C,经过 Softmax 后得到

q i = exp ⁡ ( x i ) ∑ c = 1 C exp ⁡ ( x c ) q_i=\\frac\\exp(x_i)\\sum_c=1^C\\exp(x_c) qi=c=1Cexp(xc)exp(xi)

从而该样本的交叉熵损失为

H ( p , q ) = − ∑ i = 1 C p i log ⁡ q i = − ∑ i = 1 C p i log ⁡ exp ⁡ ( x i ) ∑ c = 1 C exp ⁡ ( x c ) H(p,q)=-\\sum_i=1^C p_i\\log q_i=-\\sum_i=1^C p_i\\log\\frac\\exp(x_i)\\sum_c=1^C\\exp(x_c) H(p,q)=i=1Cpilogqi=i=1Cpilogc=1Cexp(xc)exp(xi)

其中 ( p 1 , p 2 , ⋯   , p C ) (p_1,p_2,\\cdots,p_C) (p1,p2,,pC) 是 One-Hot 向量。

不妨令 p y = 1   ( y ∈ 1 , 2 , ⋯   , C ) p_y=1\\,(y\\in\\1,2,\\cdots,C\\) py=1(y1,2,,C),其余为 0 0 0,因此上式变为

H ( p , q ) = − log ⁡ exp ⁡ ( x y ) ∑ c = 1 C exp ⁡ ( x c ) H(p,q)=-\\log\\frac\\exp(x_y)\\sum_c=1^C\\exp(x_c) H(p,q)=logc=1Cexp(xc)exp(xy)

现在考虑有 batch 的情形,不妨设 batch size 为 N N N,神经网络的输出为 x n c n c ,    n = 1 , ⋯   , N ,    c = 1 , ⋯   , C \\x_nc\\_nc,\\;n=1,\\cdots,N,\\;c=1,\\cdots,C xncnc,n=1,,N,c=1,,C,第 n n n 个样本的真实类别记为 y n   ( y n ∈ 1 , 2 , ⋯   , C ) y_n\\,(y_n\\in\\1,2,\\cdots,C\\) yn(yn1,2,,C),第 n n n 个样本的交叉熵损失记为 l n l_n ln,则仿照上式就有

l n = − log ⁡ exp ⁡ ( x n , y n ) ∑ c = 1 C exp ⁡ ( x n c ) l_n=-\\log \\frac\\exp(x_n,y_n)\\sum_c=1^C\\exp(x_nc) ln=logc=1Cexp(xnc)exp(xn,yn)

接下来我们讨论一些特殊情形。当数据不平衡时(某一类的样本数特别多,另一类的样本数特别少),我们需要为每一类的损失安排一个权重用来平衡。权重为 w = ( w 1 , w 2 , ⋯   , w C ) \\boldsymbolw=(w_1,w_2,\\cdots,w_C) w=(w1,w2,,wC)

📌 模型容易在样本数最多的一个(或几个)类上过拟合,因此对于那些样本数较少的类,我们需要设置更高的权重,这样模型在预测这些类的标签时一旦出错,就会受到更多的惩罚

安排了权重后,相应的损失为

l n = − w y n log ⁡ exp ⁡ ( x n , y n ) ∑ c = 1 C exp ⁡ ( x n c ) l_n=-w_y_n\\log \\frac\\exp(x_n,y_n)\\sum_c=1^C\\exp(x_nc) ln=wynlogc=1Cexp(xnc)exp(xn,yn)

计算完 l 1 , l 2 , ⋯   , l N l_1,l_2,\\cdots,l_N l1,l2,,lN 后,我们既可以一次性将它们全部返回(对应 reduction=none),也可以返回它们的均值(对应 reduction=mean),还可以返回它们的(对应 reduction=sum):

ℓ = ( l 1 , ⋯   , l N ) ,

以上是关于深入浅出PyTorch中的nn.CrossEntropyLoss的主要内容,如果未能解决你的问题,请参考以下文章

翻译: 深入深度学习 2.3. 线性代数 pytorch

最新PyTorch0.4.0教程01PyTorch的动态计算图深入浅出

翻译: 2.7. 如何利用帮助文档 深入神经网络 pytorch

DataWhales深入浅出Pytorch-第二章

DataWhales深入浅出Pytorch-第二章

DataWhales深入浅出Pytorch-第二章