Understanding self.training in PyTorch
Posted by AI浩
While reading model source code recently, I noticed that some models check self.training inside their forward function to decide whether they are in training or inference mode. How does this work? Let's trace through the source:
if self.training:
    return x, x_dist
else:
    # during inference, return the average of both classifier predictions
    return (x + x_dist) / 2
This snippet comes from the DeiT codebase. During training, self.training is True; during inference, it is False. If you search the file for the training attribute, you find only this single occurrence: it is used here but never assigned. The parent class VisionTransformer does not define it either, so we have to keep climbing the hierarchy until we reach nn.Module, where the attribute is finally declared.
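To see the mechanism in isolation, here is a minimal sketch that reproduces the two-head pattern above. TinyDistilledHead is a hypothetical toy module for illustration, not DeiT's actual class:

import torch
import torch.nn as nn

class TinyDistilledHead(nn.Module):
    """Toy module mimicking the DeiT pattern: two classifier heads
    during training, their averaged prediction at inference time."""
    def __init__(self, dim, num_classes):
        super().__init__()
        self.head = nn.Linear(dim, num_classes)
        self.head_dist = nn.Linear(dim, num_classes)

    def forward(self, feats):
        x = self.head(feats)
        x_dist = self.head_dist(feats)
        if self.training:  # inherited from nn.Module, toggled by train()/eval()
            return x, x_dist
        # during inference, return the average of both classifier predictions
        return (x + x_dist) / 2

m = TinyDistilledHead(8, 10)
feats = torch.randn(2, 8)
m.train()
print(type(m(feats)))  # <class 'tuple'> -- two logit tensors
m.eval()
print(m(feats).shape)  # torch.Size([2, 10]) -- averaged logits

Calling m.train() and m.eval() is all it takes to switch which branch forward executes.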
class Module:
    r"""Base class for all neural network modules.

    Your models should also subclass this class.

    Modules can also contain other Modules, allowing to nest them in
    a tree structure. You can assign the submodules as regular attributes::

        import torch.nn as nn
        import torch.nn.functional as F

        class Model(nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = nn.Conv2d(1, 20, 5)
                self.conv2 = nn.Conv2d(20, 20, 5)

            def forward(self, x):
                x = F.relu(self.conv1(x))
                return F.relu(self.conv2(x))

    Submodules assigned in this way will be registered, and will have their
    parameters converted too when you call :meth:`to`, etc.

    .. note::
        As per the example above, an ``__init__()`` call to the parent class
        must be made before assignment on the child.

    :ivar training: Boolean represents whether this module is in training or
                    evaluation mode.
    :vartype training: bool
    """

    dump_patches: bool = False

    _version: int = 1
    r"""This allows better BC support for :meth:`load_state_dict`. In
    :meth:`state_dict`, the version number will be saved as in the attribute
    `_metadata` of the returned state dict, and thus pickled. `_metadata` is a
    dictionary with keys that follow the naming convention of state dict. See
    ``_load_from_state_dict`` on how to use this information in loading.

    If new parameters/buffers are added/removed from a module, this number shall
    be bumped, and the module's `_load_from_state_dict` method can compare the
    version number and do appropriate changes if the state dict is from before
    the change."""

    training: bool
    _parameters: Dict[str, Optional[Parameter]]
    _buffers: Dict[str, Optional[Tensor]]
    _non_persistent_buffers_set: Set[str]
    _backward_hooks: Dict[int, Callable]
    _is_full_backward_hook: Optional[bool]
    _forward_hooks: Dict[int, Callable]
    _forward_pre_hooks: Dict[int, Callable]
    _state_dict_hooks: Dict[int, Callable]
    _load_state_dict_pre_hooks: Dict[int, Callable]
    _load_state_dict_post_hooks: Dict[int, Callable]
    _modules: Dict[str, Optional['Module']]
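Note that the class body above only declares the attribute's type; the value itself is assigned in Module.__init__, which sets self.training = True. So a freshly constructed module starts in training mode, which a quick check confirms:

import torch.nn as nn

m = nn.Linear(4, 4)
print(m.training)         # True -- set in Module.__init__
print(m.eval().training)  # False -- eval() flips the flag (see below)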
Continuing through the source, everything becomes clear once we reach the train() and eval() methods:
def train(self: T, mode: bool = True) -> T:
    r"""Sets the module in training mode.

    This has any effect only on certain modules. See documentations of
    particular modules for details of their behaviors in training/evaluation
    mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,
    etc.

    Args:
        mode (bool): whether to set training mode (``True``) or evaluation
                     mode (``False``). Default: ``True``.

    Returns:
        Module: self
    """
    if not isinstance(mode, bool):
        raise ValueError("training mode is expected to be boolean")
    self.training = mode
    for module in self.children():
        module.train(mode)
    return self

def eval(self: T) -> T:
    r"""Sets the module in evaluation mode.

    This has any effect only on certain modules. See documentations of
    particular modules for details of their behaviors in training/evaluation
    mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,
    etc.

    This is equivalent with :meth:`self.train(False) <torch.nn.Module.train>`.

    See :ref:`locally-disable-grad-doc` for a comparison between
    `.eval()` and several similar mechanisms that may be confused with it.

    Returns:
        Module: self
    """
    return self.train(False)
So the answer: when we call model.train() before training, training is set to True; when we call model.eval(), training is set to False. Because train() recurses into self.children(), the flag is updated on every submodule as well.
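A quick sanity check confirms that the flag propagates recursively through self.children():

import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(0.5))
model.eval()
print([m.training for m in model.children()])  # [False, False]
model.train()
print([m.training for m in model.children()])  # [True, True]

This is also why layers such as Dropout and BatchNorm change behavior after model.eval(): their forward methods read the very same self.training flag.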