损失函数
Posted liujianing
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了损失函数相关的知识,希望对你有一定的参考价值。
1、torch.nn.CrossEntropyLoss()
loss_func=torch.nn.CrossEntropyLoss()
loss=loss_func(input_data,input_target)
其中input_data的shape一般是(batch_size,output_features),而input_target的shape是(batch_size)
返回的loss是一个张量,但是只有一个数,代表的是计算结果的交叉商损失值
交叉商的计算方法是:
将输入的数据在最后一个维度上做softmax运算
对softmax后的数据取log,注意softmax后所有的数值介于0和1之间,所以log后所有的数值全都是负数
softmax_loged_data=torch.log(torch.nn.Softmax(dim=-1)(input_data))
根据标签对应的数值去softmax_loged_data中索引出相应的数值并且去掉符号,
将这batch_size个数值相加取平均后就是input_data与input_target的交叉商损失值
以上是关于损失函数的主要内容,如果未能解决你的问题,请参考以下文章