什么是 pytorch 中的 model.training?

Posted

技术标签:

【中文标题】什么是 pytorch 中的 model.training?【英文标题】:what is model.training in pytorch? 【发布时间】:2021-11-21 01:57:13 【问题描述】:

您好,我正在阅读有关迁移学习的 pytorch 教程。 (https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html)

model.training 是干什么用的?

enter def visualize_model(model,num_images=6):
was_training=model.training
model.eval()
images_so_far=0
fig=plt.figure()

with torch.no_grad():
    for i, (inputs,labels) in enumerate(dataloaders['val']):
        inputs=inputs.to(device)
        labels=labels.to(device)
        
        outputs=model(inputs)
        _,pred=torch.max(outputs,1)
        
        for j in range(inputs.size()[0]):
            images_so_far+=1
            ax=plt.subplot(num_images//2,2,images_so_far)
            ax.axis('off')
            ax.set_title('predicted: '.format(class_names[preds[j]]))
            imshow(inputs.cpu().data[j])
            
            if images_so_far==num_images:
                model.train(mode=was_training)
                return
    model.train(mode=was_training)code here

我无法理解“model.train(model=was_training)”。有什么帮助吗??非常感谢

【问题讨论】:

这能回答你的问题吗? What does model.train() do in PyTorch? 哦,谢谢!但现在我想知道他们为什么在测试会话中使用 model.train。为什么他们把代码放在“with torch.no_grad()”里面?? was_training=false 不是很明显吗?? 【参考方案1】:

我想知道他们为什么在测试会话中使用model.train。为什么他们将代码放在with torch.no_grad() 中? was_training=false 不是很明显吗?

train 的用法有点误导,因为 train 也可用于将模型置于推理(评估)模式

>>> model.train(mode=True)
>>> model.training 
True   # <- train mode

>>> model.train(mode=False)
False  # <- eval mode

我同意这并不理想,更合适的表述应该是:

>>> model.eval()

【讨论】:

【参考方案2】:

我认为这会有所帮助 (link)

所有的nn.Modules都有一个内部的training属性,通过调用model.train()和model.eval()来切换模型的行为。

was_training 变量存储模型的当前训练状态,调用 model.eval(),最后使用 model.train(training=was_training) 重置状态。

您可以在 pytorch 讨论论坛中找到很好的答案;)

【讨论】:

以上是关于什么是 pytorch 中的 model.training?的主要内容,如果未能解决你的问题,请参考以下文章

PyTorch知识点总结100问

什么是Pytorch?掌握Pytorch的基本元素操作运算操作

pytorch 中的 torch.nn.gru 函数的输入是啥?

为啥 PyTorch 中的嵌入实现为稀疏层?

用外行术语来说,pytorch 中的聚集函数有啥作用?

PyTorch 中的无维转置