model.trian()及model.eval()
Posted henry-zhao
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了model.trian()及model.eval()相关的知识,希望对你有一定的参考价值。
net.eval() #评估模式,就是net.train(False)。
设置之后会对前向传播相关进行过滤,会关闭dropout BN等 #如果网络本身没有BN和dropout,那就没区别了。
net.train():默认参数是Train。model.train()会启动drop 和 BN,但是model.train(False)不会
如果模型中有BN层(Batch Normalization)和Dropout,需要在训练时添加model.train(),在测试时添加model.eval()。其中model.train()是保证BN层用每一批数据的均值和方差,而model.eval()是保证BN用全部训练数据的均值和方差;
而对于Dropout,model.train()是随机取一部分网络连接来训练更新参数,而model.eval()是利用到了所有网络连接。
以上是关于model.trian()及model.eval()的主要内容,如果未能解决你的问题,请参考以下文章
model.train()model.eval()optimizer.zero_grad()loss.backward()optimizer.step作用及原理详解Pytorch入门手册
Pytorch中的 model.train() 和 model.eval() 模式
model.train()model.eval()什么时候用