不降低PyTorch版本解决AttributeError: module ‘torch.onnx‘ has no attribute ‘set_training‘

Posted 夏小悠

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了不降低PyTorch版本解决AttributeError: module ‘torch.onnx‘ has no attribute ‘set_training‘相关的知识,希望对你有一定的参考价值。

前言

  在项目开发过程中,遇到了一个PyTorch版本更新带来的问题,搜索了相关博客之后,都是降低PyTorch的版本,看似解决了问题,实则不然,治标不治本而已。本篇博客主要介绍一下如何从根源上解决这些问题。

1. 问题描述

# 此处粘贴了问题出错的部分代码
......
with torch.onnx.set_training(model_clone, False):
	device = distiller.model_device(model_clone)
	dummy_input = distiller.convert_tensors_recursively_to(dummy_input, device=device)
	self.dummy_input = dummy_input
	......

  运行项目的过程中出现错误:AttributeError: module 'torch.onnx' has no attribute 'set_training'

2. 问题原因

  查询了相关博客说是PyTorch的版本过高需要降低版本即可解决,没错,该项目所附带的requirements.txt要求的PyTorch版本为1.3.1。由于我所开发的项目实是在新版本上(PyTorch 1.8.1)开发的,所以降低PyTorch的版本解决此问题不是一个好的 idea 。

3. 解决过程

  要想解决这个问题,肯定是要知道torch.onnx.set_training()这行代码是干什么的,看了一下PyTorch 1.3.1官方文档,官方是这样说的:

def set_training(model, mode):
    r"""
    A context manager to temporarily set the training mode of 'model'
    to 'mode', resetting it when we exit the with-block.  A no-op if
    mode is None.
    """

    from torch.onnx import utils
    return utils.set_training(model, mode)

  大致意思就是说,在和一个with上下文管理器一起用是,其作用就是临时将model设置为mode模式,退出with再将modelmode重置。
  然后我看了一下官方的源码,发现也是很简单,就一行。我尝试在新版本中导出一下from torch.onnx.utils import set_training,发现没有此函数,ennnnn,合情合理,但有了另一个函数from torch.onnx.utils import select_model_mode_for_export,从函数名字上感觉很像,而且参数一毛一样,看了一下PyTorch 1.8.1版本的官方文档

def select_model_mode_for_export(model, mode):
    r"""
    A context manager to temporarily set the training mode of 'model'
    to 'mode', resetting it when we exit the with-block.  A no-op if
    mode is None.

    In version 1.6 changed to this from set_training
    """

    from torch.onnx import utils
    return utils.select_model_mode_for_export(model, mode)

  很明确了,在PyTorch 1.6版本中set_training变成了select_model_mode_for_export,改一下就可以了。

以上是关于不降低PyTorch版本解决AttributeError: module ‘torch.onnx‘ has no attribute ‘set_training‘的主要内容,如果未能解决你的问题,请参考以下文章

第二讲Pytorch使用

为啥损失减少但准确性也降低(Pytorch,LSTM)?

安装GPU版本的pytorch(解决pytorch安装时默认安装CPU版本的问题)保姆级教程

conda安装GPU版pytorch,结果却是cpu版本[找到问题根源,从容解决]

[深度学习][pytorch][原创]crnn在高版本pytorch上训练loss为nan解决办法

如何使用poi 解决word2003中文符替换问题,我现在用的2007的高版本,想降低版本,不知道怎么实现?