PyTorch学习笔记:nn.MSELoss——MSE损失
Posted 视觉萌新、
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了PyTorch学习笔记:nn.MSELoss——MSE损失相关的知识,希望对你有一定的参考价值。
PyTorch学习笔记:nn.MSELoss——MSE损失
torch.nn.MSELoss(size_average = None,reduce = None,reduction = 'mean')
功能:创建一个平方误差(MSE)损失函数,又称为L2损失:
l
(
x
,
y
)
=
L
=
l
1
,
…
,
l
N
T
,
l
n
=
(
x
n
−
y
n
)
2
l(x,y)=L=\\l_1,\\dots,l_N\\^T,l_n=(x_n-y_n)^2
l(x,y)=L=l1,…,lNT,ln=(xn−yn)2
其中,
N
N
N表示batch size。
函数图像:
输入:
size_average
与reduce
已经被弃用,具体功能可由reduction
替代reduction
:指定损失输出的形式,有三种选择:none
|mean
|sum
。none
:损失不做任何处理,直接输出一个数组;mean
:将得到的损失求平均值再输出,会输出一个数;sum
:将得到的损失求和再输出,会输出一个数
注意:
- 输入的 x x x与 y y y可以是任意维数的数组,但是二者形状必须一致
代码案例
对比reduction
不同时,输出损失的差异
import torch.nn as nn
import torch
x = torch.rand(10, dtype=torch.float)
y = torch.rand(10, dtype=torch.float)
mse_none = nn.MSELoss(reduction='none')
mse_mean = nn.MSELoss(reduction='mean')
mse_sum = nn.MSELoss(reduction='sum')
out_none = mse_none(x, y)
out_mean = mse_mean(x, y)
out_sum = mse_sum(x, y)
print(x)
print(y)
print(out_none)
print(out_mean)
print(out_sum)
输出
# 用于输入的x
tensor([0.4138, 0.1747, 0.9259, 0.2938, 0.5557, 0.9708, 0.0649, 0.6155, 0.3192, 0.1918])
# 用于输入的y
tensor([0.1024, 0.9160, 0.8386, 0.0783, 0.1479, 0.9933, 0.8791, 0.4219, 0.7586, 0.2212])
# 当reduction设置为none时,输出一个数组
# 该数组上的元素为x,y对应每个元素的平方误差损失,即对应元素做差求平方
tensor([9.6983e-02, 5.4955e-01, 7.6214e-03, 4.6433e-02, 1.6630e-01, 5.0293e-04, 6.6287e-01, 3.7512e-02, 1.9310e-01, 8.6344e-04])
# 当reduction设置为mean时,输出所有损失的平均值
tensor(0.1762)
# 当reduction设置为sum时,输出所有损失的和
tensor(1.7617)
注:绘图程序
import torch.nn as nn
import torch
import numpy as np
import matplotlib.pyplot as plt
loss = nn.MSELoss(reduction='none')
x = torch.tensor([0]*100)
y = torch.from_numpy(np.linspace(-3,3,100))
loss_value = loss(x,y)
plt.plot(y, loss_value)
plt.savefig('MSELoss.jpg')
官方文档
nn.MSELoss:https://pytorch.org/docs/stable/generated/torch.nn.MSELoss.html#torch.nn.MSELoss
初步完稿于:2022年1月29日
以上是关于PyTorch学习笔记:nn.MSELoss——MSE损失的主要内容,如果未能解决你的问题,请参考以下文章
Pytorch常用损失函数nn.BCEloss();nn.BCEWithLogitsLoss();nn.CrossEntropyLoss();nn.L1Loss(); nn.MSELoss();(代码