model.train()model.eval()什么时候用

Posted WTIAW.TIAW

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了model.train()model.eval()什么时候用相关的知识,希望对你有一定的参考价值。

model.train()

在使用 pytorch 构建神经网络的时候,训练过程中会在程序上方添加一句model.train(),作用是 启用 batch normalizationdropout

如果模型中有BN层(Batch Normalization)Dropout ,需要在训练时添加 model.train()。

model.train() 是保证 BN 层能够用到 每一批数据 的均值和方差。对于 Dropoutmodel.train() 是 随机取一部分网络连接来训练更新参数。

model.eval()

model.eval() 作用等同于 self.train(False)
简而言之,就是评估模式。而非训练模式。
在评估模式下,batchNorm层,dropout层等用于优化训练而添加的网络层会被关闭,从而使得评估时不会发生偏移。

在对模型进行评估时,应该配合使用with torch.no_grad() 与 model.eval()

    loop:
        model.train()    # 切换至训练模式
        train……
        model.eval()
        with torch.no_grad():
            Evaluation
    end loop

总结:

如果模型中有 BN 层(Batch Normalization)和 Dropout,需要在训练时添加 model.train(),在测试时添加 model.eval()。

其中 model.train() 是保证 BN 层用每一批数据的均值和方差,而 model.eval() 是保证 BN 用全部训练数据的均值和方差;

而对于 Dropout,model.train() 是随机取一部分网络连接来训练更新参数,而 model.eval() 是利用到了所有网络连接

model.trian()及model.eval()

net.eval() #评估模式,就是net.train(False)。
设置之后会对前向传播相关进行过滤,会关闭dropout BN等 #如果网络本身没有BN和dropout,那就没区别了。

net.train():默认参数是Train。model.train()会启动drop 和 BN,但是model.train(False)不会

如果模型中有BN层(Batch Normalization)和Dropout,需要在训练时添加model.train(),在测试时添加model.eval()。其中model.train()是保证BN层用每一批数据的均值和方差,而model.eval()是保证BN用全部训练数据的均值和方差;
而对于Dropout,model.train()是随机取一部分网络连接来训练更新参数,而model.eval()是利用到了所有网络连接。



以上是关于model.train()model.eval()什么时候用的主要内容,如果未能解决你的问题,请参考以下文章

model.trian()及model.eval()

pytorch踩坑之model.eval()和model.train()输出差距很大

model.train()model.eval()optimizer.zero_grad()loss.backward()optimizer.step作用及原理详解Pytorch入门手册

model.train()model.eval()optimizer.zero_grad()loss.backward()optimizer.step作用及原理详解Pytorch入门手册

[Pytorch系列-38]: 工具集 - torchvision预定义模型的两种模式model.train和model.eval的表面和本质区别

pytorch 笔记:validation ,model.eval V.S torch.no_grad