PyTorch - Sequential和ModuleList
Posted SpikeKing
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了PyTorch - Sequential和ModuleList相关的知识,希望对你有一定的参考价值。
module.py
Module Class,继承于torch.nn.Module
-
train()函数:train(mode=True),当前self.training=True,所有子模块children,都设置training=True
-
__init__
,设置self.training参数,子类会使用training参数 -
Dropout源码,Dropout -> _DropoutNd -> Module、BatchNorm源码,都使用training模式
-
BN: Buffers are only updated if they are to be tracked and we are in training mode.
-
-
super(_DropoutNd, self).__init__()
-
-
eval()函数:train(mode=Fale)
-
requires_grad_()
:当前模型的所有参数,module的函数,和parameter的函数,参数计算梯度 -
zero_grad()
:调用优化器的zero_grad()
,将所有的参数的梯度都清0,避免梯度累积,优化器设置zero_grad,不需要调用模型 -
__repr__()
:魔法函数,string的表示,名称+模块描述 -
__dir__()
:attrs、parameters、modules、buffers、keys,返回所有键值
module.py,Module Class的源码
container.py
- Container已经过期
- Sequential(Module),有序的,直接传入Module的实例,或传入OrderedDict(),包含键值,最常用。
- 或者传入键值,或者键值为默认idx递增序列
s = torch.nn.Sequential(torch.nn.Linear(2,3), torch.nn.Linear(3,4))
s._modules
OrderedDict([('0', Linear(in_features=2, out_features=3, bias=True)),
('1', Linear(in_features=3, out_features=4, bias=True))])
- forward(),input输入module,输出input,循环连续处理input
- ModuleList(Module),所有子module都放在list中,存放module的列表
- 将modules添加到ModuleList中,insert或append函数,比list拥有更多的module父类的方法
- ModuleDict(Module),可以通过key去访问不同的module,本身是一个dict,又是一个module,可以用于module的子module
- ParameterList(Module),把parameter类型参数放入列表中,mm = matrix multiplication,矩阵乘法
- ParameterDict(Module),同上
- Module <-> Parameter,都是container,容器,List和Dict只有存放,没有forward功能,只有Sequential包含forward功能
以上是关于PyTorch - Sequential和ModuleList的主要内容,如果未能解决你的问题,请参考以下文章
如何在 Pytorch 的“nn.Sequential”中展平输入
pytorch中的顺序容器——torch.nn.Sequential
Pytorch学习笔记——Sequential类参数管理与GPU
pytorch教程之nn.Sequential类详解——使用Sequential类来自定义顺序连接模型