PyTorch基础(12)-- torch.nn.BatchNorm2d()方法

Posted 奋斗丶

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了PyTorch基础(12)-- torch.nn.BatchNorm2d()方法相关的知识,希望对你有一定的参考价值。

Batch Normanlization简称BN,也就是数据归一化,对深度学习模型性能的提升有很大的帮助。BN的原理可以查阅我之前的一篇博客。白话详细解读(七)----- Batch Normalization。但为了该篇博客的完整性,在这里简单介绍一下BN。

一、BN的原理

BN的基本思想其实相当直观:因为深层神经网络在做非线性变换前的激活输入值(就是那个x=WU+B,U是输入)随着网络深度加深或者在训练过程中,其分布逐渐发生偏移或者变动,之所以训练收敛慢,一般是整体分布逐渐往非线性函数的取值区间的上下限两端靠近(对于Sigmoid函数来说,意味着激活输入值WU+B是大的负值或正值),所以这导致反向传播时低层神经网络的梯度消失,这是训练深层神经网络收敛越来越慢的本质原因,而BN就是通过一定的规范化手段,把每层神经网络任意神经元这个输入值的分布强行拉回到均值为0方差为1的标准正态分布,其实就是把越来越偏的分布强制拉回比较标准的分布,这样使得激活输入值落在非线性函数对输入比较敏感的区域,这样输入的小变化就会导致损失函数较大的变化,意思是这样让梯度变大,避免梯度消失问题产生,而且梯度变大意味着学习收敛速度快,能大大加快训练速度。BN具体操作流程如下图所示:

二、nn.BatchNorm2d()方法详解

清楚了BN的原理之后,便可以很快速的理解这个方法了。

  • 方法
torch.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  • Parameters

    num_features:图像的通道数,也即(N, C, H, W)中的C的值

    eps:增加至分母上的一个很小的数,为了防止/0情况的发生

    momentum:用来计算平均值和方差的值,默认值为0.1

    affine:一个布尔类型的值,当设置为True的时候,该模型对affine参数具有可学习的能力,默认为True

    track_running_stats:一个布尔类型的值,用于记录均值和方差,当设置为True的时候,模型会跟踪均值和方差,反之,不会跟踪均值和方差

  • Shape

    Input: (N, C, H, W)
    Output: (N, C, H, W)

三、案例分析

import torch.nn as nn
import torch
if __name__ == '__main__':
    bn = nn.BatchNorm2d(3)
    ip = torch.randn(2, 3, 2, 2)
    print(ip)
    output = bn(ip)
    print(output)
  • 运行结果

以上是关于PyTorch基础(12)-- torch.nn.BatchNorm2d()方法的主要内容,如果未能解决你的问题,请参考以下文章

[Pytorch系列-28]:神经网络基础 - torch.nn模块功能列表

PyTorch基础(13)-- torch.nn.Unfold()方法

PyTorch基础(13)-- torch.nn.Unfold()方法

[Pytorch系列-30]:神经网络基础 - torch.nn库五大基本功能:nn.Parameternn.Linearnn.functioinalnn.Modulenn.Sequentia(代码片

pytorch中的顺序容器——torch.nn.Sequential

pytorch基础