PyTorch - Sequential和ModuleList

Posted SpikeKing

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了PyTorch - Sequential和ModuleList相关的知识,希望对你有一定的参考价值。

module.py

Module Class,继承于torch.nn.Module

  • train()函数:train(mode=True),当前self.training=True,所有子模块children,都设置training=True

    • __init__,设置self.training参数,子类会使用training参数

    • Dropout源码,Dropout -> _DropoutNd -> Module、BatchNorm源码,都使用training模式

      • BN: Buffers are only updated if they are to be tracked and we are in training mode.
        
    • super(_DropoutNd, self).__init__()

  • eval()函数:train(mode=Fale)

  • requires_grad_():当前模型的所有参数,module的函数,和parameter的函数,参数计算梯度

  • zero_grad():调用优化器的zero_grad(),将所有的参数的梯度都清0,避免梯度累积,优化器设置zero_grad,不需要调用模型

  • __repr__():魔法函数,string的表示,名称+模块描述

  • __dir__():attrs、parameters、modules、buffers、keys,返回所有键值

module.py,Module Class的源码

container.py

  • Container已经过期
  • Sequential(Module),有序的,直接传入Module的实例,或传入OrderedDict(),包含键值,最常用
  • 或者传入键值,或者键值为默认idx递增序列
s = torch.nn.Sequential(torch.nn.Linear(2,3), torch.nn.Linear(3,4))
s._modules
OrderedDict([('0', Linear(in_features=2, out_features=3, bias=True)),
             ('1', Linear(in_features=3, out_features=4, bias=True))])
  • forward(),input输入module,输出input,循环连续处理input
  • ModuleList(Module),所有子module都放在list中,存放module的列表
  • 将modules添加到ModuleList中,insert或append函数,比list拥有更多的module父类的方法
  • ModuleDict(Module),可以通过key去访问不同的module,本身是一个dict,又是一个module,可以用于module的子module
  • ParameterList(Module),把parameter类型参数放入列表中,mm = matrix multiplication,矩阵乘法
  • ParameterDict(Module),同上
  • Module <-> Parameter,都是container,容器,List和Dict只有存放,没有forward功能,只有Sequential包含forward功能

以上是关于PyTorch - Sequential和ModuleList的主要内容,如果未能解决你的问题,请参考以下文章

如何在 Pytorch 的“nn.Sequential”中展平输入

pytorch中的顺序容器——torch.nn.Sequential

Pytorch学习笔记——Sequential类参数管理与GPU

pytorch教程之nn.Sequential类详解——使用Sequential类来自定义顺序连接模型

pytorch教程之nn.Sequential类详解——使用Sequential类来自定义顺序连接模型

PyTorch 中的 nn.functional() 与 nn.sequential() 之间是不是存在计算效率差异