为啥这个 Pytorch 官方教程中没有 .train() 方法?
Posted
技术标签:
【中文标题】为啥这个 Pytorch 官方教程中没有 .train() 方法?【英文标题】:Why there's no .train() method in this Pytorch official tutorial?为什么这个 Pytorch 官方教程中没有 .train() 方法? 【发布时间】:2019-11-16 23:07:52 【问题描述】:所以,我刚刚学习了 Pytorch,他们说你必须通过 .train() 方法将 NN 置于训练模式,然后在推断 .eval() 模式时。我正在阅读本教程,根本没有 .train() 。这是为什么呢?
https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#sphx-glr-beginner-blitz-cifar10-tutorial-py
【问题讨论】:
What does model.train() do in pytorch?的可能重复 请不要同时在 SO 和 Discuss PyTorch 上发布相同的问题。至少,等待一天在其中一个中得到答案。 【参考方案1】:.train()
将模块的self.training
属性设置为True
。 .eval()
将其设置为 False
。
从source for nn.Module
中可以看出,此属性最初设置为True
。因此,除非您在开始训练之前致电eval()
,否则您不一定(根据当前实施)需要致电train()
。 但是重要的是,模块在训练时应该处于self.training=True
状态,所以无论如何这样做可能是一个好习惯。
此外,目前,只有一些模块(例如 dropout 和 batchnorm)会根据 self.training
属性更改其行为。因此,如果您不使用这些特定模块,则不必一定调用 .train()
和 .eval()
,但同样,无论如何,这样做可能是一种很好的做法,可以让您的代码面向未来.
【讨论】:
以上是关于为啥这个 Pytorch 官方教程中没有 .train() 方法?的主要内容,如果未能解决你的问题,请参考以下文章
PyTorch 1.0 中文官方教程:使用 PyTorch 进行图像风格转换
PyTorch 1.0 中文官方教程:使用ONNX将模型从PyTorch传输到Caffe2和移动端