深入浅出PyTorch中的nn.CrossEntropyLoss
Posted aelum
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了深入浅出PyTorch中的nn.CrossEntropyLoss相关的知识,希望对你有一定的参考价值。
目录
一、前言
nn.CrossEntropyLoss
常用作多分类问题的损失函数(对交叉熵还不了解的读者可以看我的这篇文章),本文将围绕PyTorch的官方文档对重要知识点进行逐一讲解(不会全部讲解)。
import torch
import torch.nn as nn
二、理论基础
对于 C ( C > 2 ) C\\,(C>2) C(C>2) 分类问题,先不考虑 batch 的情形,设神经网络的输出(还未经过 Softmax)为 x c c = 1 C \\x_c\\_c=1^C xcc=1C,经过 Softmax 后得到
q i = exp ( x i ) ∑ c = 1 C exp ( x c ) q_i=\\frac\\exp(x_i)\\sum_c=1^C\\exp(x_c) qi=∑c=1Cexp(xc)exp(xi)
从而该样本的交叉熵损失为
H ( p , q ) = − ∑ i = 1 C p i log q i = − ∑ i = 1 C p i log exp ( x i ) ∑ c = 1 C exp ( x c ) H(p,q)=-\\sum_i=1^C p_i\\log q_i=-\\sum_i=1^C p_i\\log\\frac\\exp(x_i)\\sum_c=1^C\\exp(x_c) H(p,q)=−i=1∑Cpilogqi=−i=1∑Cpilog∑c=1Cexp(xc)exp(xi)
其中 ( p 1 , p 2 , ⋯ , p C ) (p_1,p_2,\\cdots,p_C) (p1,p2,⋯,pC) 是 One-Hot 向量。
不妨令 p y = 1 ( y ∈ 1 , 2 , ⋯ , C ) p_y=1\\,(y\\in\\1,2,\\cdots,C\\) py=1(y∈1,2,⋯,C),其余为 0 0 0,因此上式变为
H ( p , q ) = − log exp ( x y ) ∑ c = 1 C exp ( x c ) H(p,q)=-\\log\\frac\\exp(x_y)\\sum_c=1^C\\exp(x_c) H(p,q)=−log∑c=1Cexp(xc)exp(xy)
现在考虑有 batch 的情形,不妨设 batch size 为 N N N,神经网络的输出为 x n c n c , n = 1 , ⋯ , N , c = 1 , ⋯ , C \\x_nc\\_nc,\\;n=1,\\cdots,N,\\;c=1,\\cdots,C xncnc,n=1,⋯,N,c=1,⋯,C,第 n n n 个样本的真实类别记为 y n ( y n ∈ 1 , 2 , ⋯ , C ) y_n\\,(y_n\\in\\1,2,\\cdots,C\\) yn(yn∈1,2,⋯,C),第 n n n 个样本的交叉熵损失记为 l n l_n ln,则仿照上式就有
l n = − log exp ( x n , y n ) ∑ c = 1 C exp ( x n c ) l_n=-\\log \\frac\\exp(x_n,y_n)\\sum_c=1^C\\exp(x_nc) ln=−log∑c=1Cexp(xnc)exp(xn,yn)
接下来我们讨论一些特殊情形。当数据不平衡时(某一类的样本数特别多,另一类的样本数特别少),我们需要为每一类的损失安排一个权重用来平衡。权重为 w = ( w 1 , w 2 , ⋯ , w C ) \\boldsymbolw=(w_1,w_2,\\cdots,w_C) w=(w1,w2,⋯,wC)。
📌 模型容易在样本数最多的一个(或几个)类上过拟合,因此对于那些样本数较少的类,我们需要设置更高的权重,这样模型在预测这些类的标签时一旦出错,就会受到更多的惩罚
安排了权重后,相应的损失为
l n = − w y n log exp ( x n , y n ) ∑ c = 1 C exp ( x n c ) l_n=-w_y_n\\log \\frac\\exp(x_n,y_n)\\sum_c=1^C\\exp(x_nc) ln=−wynlog∑c=1Cexp(xnc)exp(xn,yn)
计算完
l
1
,
l
2
,
⋯
,
l
N
l_1,l_2,\\cdots,l_N
l1,l2,⋯,lN 后,我们既可以一次性将它们全部返回(对应 reduction=none
),也可以返回它们的均值(对应 reduction=mean
),还可以返回它们的和(对应 reduction=sum
):
ℓ
=
(
l
1
,
⋯
,
l
N
)
,
以上是关于深入浅出PyTorch中的nn.CrossEntropyLoss的主要内容,如果未能解决你的问题,请参考以下文章 最新PyTorch0.4.0教程01PyTorch的动态计算图深入浅出