model.train()model.eval()什么时候用
Posted WTIAW.TIAW
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了model.train()model.eval()什么时候用相关的知识,希望对你有一定的参考价值。
model.train()
在使用 pytorch 构建神经网络的时候,训练过程中会在程序上方添加一句model.train(),作用是 启用 batch normalization 和 dropout 。
如果模型中有BN层(Batch Normalization)和 Dropout ,需要在训练时添加 model.train()。
model.train() 是保证 BN 层能够用到 每一批数据 的均值和方差。对于 Dropout,model.train() 是 随机取一部分网络连接来训练更新参数。
model.eval()
model.eval() 作用等同于 self.train(False)
简而言之,就是评估模式。而非训练模式。
在评估模式下,batchNorm层,dropout层等用于优化训练而添加的网络层会被关闭,从而使得评估时不会发生偏移。
在对模型进行评估时,应该配合使用with torch.no_grad() 与 model.eval()
loop:
model.train() # 切换至训练模式
train……
model.eval()
with torch.no_grad():
Evaluation
end loop
总结:
如果模型中有 BN 层(Batch Normalization)和 Dropout,需要在训练时添加 model.train(),在测试时添加 model.eval()。
其中 model.train() 是保证 BN 层用每一批数据的均值和方差,而 model.eval() 是保证 BN 用全部训练数据的均值和方差;
而对于 Dropout,model.train() 是随机取一部分网络连接来训练更新参数,而 model.eval() 是利用到了所有网络连接
model.trian()及model.eval()
net.eval() #评估模式,就是net.train(False)。
设置之后会对前向传播相关进行过滤,会关闭dropout BN等 #如果网络本身没有BN和dropout,那就没区别了。
net.train():默认参数是Train。model.train()会启动drop 和 BN,但是model.train(False)不会
如果模型中有BN层(Batch Normalization)和Dropout,需要在训练时添加model.train(),在测试时添加model.eval()。其中model.train()是保证BN层用每一批数据的均值和方差,而model.eval()是保证BN用全部训练数据的均值和方差;
而对于Dropout,model.train()是随机取一部分网络连接来训练更新参数,而model.eval()是利用到了所有网络连接。
以上是关于model.train()model.eval()什么时候用的主要内容,如果未能解决你的问题,请参考以下文章
pytorch踩坑之model.eval()和model.train()输出差距很大
model.train()model.eval()optimizer.zero_grad()loss.backward()optimizer.step作用及原理详解Pytorch入门手册
model.train()model.eval()optimizer.zero_grad()loss.backward()optimizer.step作用及原理详解Pytorch入门手册
[Pytorch系列-38]: 工具集 - torchvision预定义模型的两种模式model.train和model.eval的表面和本质区别