PyTorch学习笔记:nn.MSELoss——MSE损失

Posted 视觉萌新、

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了PyTorch学习笔记:nn.MSELoss——MSE损失相关的知识,希望对你有一定的参考价值。

PyTorch学习笔记:nn.MSELoss——MSE损失

torch.nn.MSELoss(size_average = Nonereduce = None,reduction = 'mean')

功能:创建一个平方误差(MSE)损失函数,又称为L2损失:
l ( x , y ) = L = l 1 , … , l N T , l n = ( x n − y n ) 2 l(x,y)=L=\\l_1,\\dots,l_N\\^T,l_n=(x_n-y_n)^2 l(x,y)=L=l1,,lNT,ln=(xnyn)2
其中, N N N表示batch size。

函数图像:

输入:

  • size_averagereduce已经被弃用,具体功能可由reduction替代
  • reduction:指定损失输出的形式,有三种选择:none|mean|sumnone:损失不做任何处理,直接输出一个数组;mean:将得到的损失求平均值再输出,会输出一个数;sum:将得到的损失求和再输出,会输出一个数

注意:

  • 输入的 x x x y y y可以是任意维数的数组,但是二者形状必须一致

代码案例

对比reduction不同时,输出损失的差异

import torch.nn as nn
import torch

x = torch.rand(10, dtype=torch.float)
y = torch.rand(10, dtype=torch.float)
mse_none = nn.MSELoss(reduction='none')
mse_mean = nn.MSELoss(reduction='mean')
mse_sum = nn.MSELoss(reduction='sum')
out_none = mse_none(x, y)
out_mean = mse_mean(x, y)
out_sum = mse_sum(x, y)
print(x)
print(y)
print(out_none)
print(out_mean)
print(out_sum)

输出

# 用于输入的x
tensor([0.4138, 0.1747, 0.9259, 0.2938, 0.5557, 0.9708, 0.0649, 0.6155, 0.3192, 0.1918])
# 用于输入的y
tensor([0.1024, 0.9160, 0.8386, 0.0783, 0.1479, 0.9933, 0.8791, 0.4219, 0.7586, 0.2212])
# 当reduction设置为none时,输出一个数组
# 该数组上的元素为x,y对应每个元素的平方误差损失,即对应元素做差求平方
tensor([9.6983e-02, 5.4955e-01, 7.6214e-03, 4.6433e-02, 1.6630e-01, 5.0293e-04, 6.6287e-01, 3.7512e-02, 1.9310e-01, 8.6344e-04])
# 当reduction设置为mean时,输出所有损失的平均值
tensor(0.1762)
# 当reduction设置为sum时,输出所有损失的和
tensor(1.7617)

注:绘图程序

import torch.nn as nn
import torch
import numpy as np
import matplotlib.pyplot as plt

loss = nn.MSELoss(reduction='none')
x = torch.tensor([0]*100)
y = torch.from_numpy(np.linspace(-3,3,100))
loss_value = loss(x,y)
plt.plot(y, loss_value)
plt.savefig('MSELoss.jpg')

官方文档

nn.MSELoss:https://pytorch.org/docs/stable/generated/torch.nn.MSELoss.html#torch.nn.MSELoss

初步完稿于:2022年1月29日

以上是关于PyTorch学习笔记:nn.MSELoss——MSE损失的主要内容,如果未能解决你的问题,请参考以下文章

Pytorch常用损失函数nn.BCEloss();nn.BCEWithLogitsLoss();nn.CrossEntropyLoss();nn.L1Loss(); nn.MSELoss();(代码

pytorch torch.nn.MSELoss(size_average=True)(均方误差损失函数)

Pytorch MSELoss

Pytorch MSELoss

神经网络架构pytorch-MSELoss损失函数

PyTorch学习4《PyTorch深度学习实践》——线性回归(Linear Regression)