不降低PyTorch版本解决AttributeError: module ‘torch.onnx‘ has no attribute ‘set_training‘
Posted 夏小悠
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了不降低PyTorch版本解决AttributeError: module ‘torch.onnx‘ has no attribute ‘set_training‘相关的知识,希望对你有一定的参考价值。
前言
在项目开发过程中,遇到了一个PyTorch
版本更新带来的问题,搜索了相关博客之后,都是降低PyTorch
的版本,看似解决了问题,实则不然,治标不治本而已。本篇博客主要介绍一下如何从根源上解决这些问题。
1. 问题描述
# 此处粘贴了问题出错的部分代码
......
with torch.onnx.set_training(model_clone, False):
device = distiller.model_device(model_clone)
dummy_input = distiller.convert_tensors_recursively_to(dummy_input, device=device)
self.dummy_input = dummy_input
......
运行项目的过程中出现错误:AttributeError: module 'torch.onnx' has no attribute 'set_training'
。
2. 问题原因
查询了相关博客说是PyTorch
的版本过高需要降低版本即可解决
,没错,该项目所附带的requirements.txt
要求的PyTorch
版本为1.3.1
。由于我所开发的项目实是在新版本上(PyTorch 1.8.1)
开发的,所以降低PyTorch
的版本解决此问题不是一个好的 idea 。
3. 解决过程
要想解决这个问题,肯定是要知道torch.onnx.set_training()
这行代码是干什么的,看了一下PyTorch 1.3.1
的官方文档,官方是这样说的:
def set_training(model, mode):
r"""
A context manager to temporarily set the training mode of 'model'
to 'mode', resetting it when we exit the with-block. A no-op if
mode is None.
"""
from torch.onnx import utils
return utils.set_training(model, mode)
大致意思就是说,在和一个with
上下文管理器一起用是,其作用就是临时将model
设置为mode
模式,退出with
再将model
的mode
重置。
然后我看了一下官方的源码,发现也是很简单,就一行。我尝试在新版本中导出一下from torch.onnx.utils import set_training
,发现没有此函数,ennnnn,合情合理,但有了另一个函数from torch.onnx.utils import select_model_mode_for_export
,从函数名字上感觉很像,而且参数一毛一样,看了一下PyTorch 1.8.1
版本的官方文档:
def select_model_mode_for_export(model, mode):
r"""
A context manager to temporarily set the training mode of 'model'
to 'mode', resetting it when we exit the with-block. A no-op if
mode is None.
In version 1.6 changed to this from set_training
"""
from torch.onnx import utils
return utils.select_model_mode_for_export(model, mode)
很明确了,在PyTorch 1.6
版本中set_training
变成了select_model_mode_for_export
,改一下就可以了。
以上是关于不降低PyTorch版本解决AttributeError: module ‘torch.onnx‘ has no attribute ‘set_training‘的主要内容,如果未能解决你的问题,请参考以下文章
安装GPU版本的pytorch(解决pytorch安装时默认安装CPU版本的问题)保姆级教程
conda安装GPU版pytorch,结果却是cpu版本[找到问题根源,从容解决]