pytorch net.train(), net.eval()
Posted
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了pytorch net.train(), net.eval()相关的知识,希望对你有一定的参考价值。
参考技术A 在训练模型时会在前面加上:在测试模型时在前面使用:
同时发现,若是不写这两个程序也能够运行,这是由于这两个方法是针对在网络训练和测试时采用不一样方式的状况,好比Batch Normalization 和 Dropout。网络
训练时是正对每一个min-batch的,可是在测试中每每是针对单张图片,即不存在min-batch的概念。因为网络训练完毕后参数都是固定的,所以每一个批次的均值和方差都是不变的,所以直接结算全部batch的均值和方差。全部Batch Normalization的训练和测试时的操做不一样
在训练中,每一个隐层的神经元先乘几率P,而后在进行激活,在测试中,全部的神经元先进行激活,而后每一个隐层神经元的输出乘P。
参考:
https://www.jb51.net/article/212962.htm
https://www.shangmayuan.com/a/e3c4302796b746179926723a.html
model.trian()及model.eval()
net.eval() #评估模式,就是net.train(False)。
设置之后会对前向传播相关进行过滤,会关闭dropout BN等 #如果网络本身没有BN和dropout,那就没区别了。
net.train():默认参数是Train。model.train()会启动drop 和 BN,但是model.train(False)不会
如果模型中有BN层(Batch Normalization)和Dropout,需要在训练时添加model.train(),在测试时添加model.eval()。其中model.train()是保证BN层用每一批数据的均值和方差,而model.eval()是保证BN用全部训练数据的均值和方差;
而对于Dropout,model.train()是随机取一部分网络连接来训练更新参数,而model.eval()是利用到了所有网络连接。
以上是关于pytorch net.train(), net.eval()的主要内容,如果未能解决你的问题,请参考以下文章