Pytorch 到 ONNX 导出功能失败并导致遗留功能错误
Posted
技术标签:
【中文标题】Pytorch 到 ONNX 导出功能失败并导致遗留功能错误【英文标题】:Pytorch to ONNX export function fails and causes legacy function error 【发布时间】:2020-02-18 21:45:58 【问题描述】:我正在尝试使用以下代码将 this 链接中的 pytorch 模型转换为 onnx 模型:
device=t.device('cuda:0' if t.cuda.is_available() else 'cpu')
print(device)
faster_rcnn = FasterRCNNVGG16()
trainer = FasterRCNNTrainer(faster_rcnn).cuda()
#trainer = FasterRCNNTrainer(faster_rcnn).to(device)
trainer.load('./checkpoints/model.pth')
dummy_input = t.randn(1, 3, 300, 300, device = 'cuda')
#dummy_input = dummy_input.to(device)
t.onnx.export(faster_rcnn, dummy_input, "model.onnx", verbose = True)
但我收到以下错误(抱歉,*** 下方的块引用不会让整个跟踪采用代码格式,否则不会让问题发布):
Traceback (most recent call last): small_object_detection_master_samirsen\onnxtest.py", line 44, in <module> t.onnx.export(faster_rcnn, dummy_input, "fasterrcnn_10120119_06025842847785781.onnx", verbose = True) File "C:\Users\HP\AppData\Local\Programs\Python\Python36\lib\site-packages\torch\onnx\__init__.py",
第 132 行,在导出中 strip_doc_string,动态轴) 文件“C:\Users\HP\AppData\Local\Programs\Python\Python36\lib\site-packages\torch\onnx\utils.py”, 第 64 行,出口 example_outputs=example_outputs,strip_doc_string=strip_doc_string,dynamic_axes=dynamic_axes) 文件“C:\Users\HP\AppData\Local\Programs\Python\Python36\lib\site-packages\torch\onnx\utils.py”, 第 329 行,在 _export 中 _retain_param_name,do_constant_folding) 文件“C:\Users\HP\AppData\Local\Programs\Python\Python36\lib\site-packages\torch\onnx\utils.py”, 第 213 行,在 _model_to_graph 图,torch_out = _trace_and_get_graph_from_model(模型,参数,训练) 文件“C:\Users\HP\AppData\Local\Programs\Python\Python36\lib\site-packages\torch\onnx\utils.py”, 第 171 行,在 _trace_and_get_graph_from_model 跟踪,torch_out = torch.jit.get_trace_graph(模型,参数,_force_outplace=True) 文件“C:\Users\HP\AppData\Local\Programs\Python\Python36\lib\site-packages\torch\jit__init__.py”, 第 256 行,在 get_trace_graph 中 return LegacyTracedModule(f, _force_outplace, return_inputs)(*args, **kwargs) 文件“C:\Users\HP\AppData\Local\Programs\Python\Python36\lib\site-packages\torch\nn\modules\module.py”, 第 547 行,在 调用 结果 = self.forward(*input, **kwargs) 文件“C:\Users\HP\AppData\Local\Programs\Python\Python36\lib\site-packages\torch\jit__init__.py”, 第 323 行,向前 out = self.inner(*trace_inputs) 文件“C:\Users\HP\AppData\Local\Programs\Python\Python36\lib\site-packages\torch\nn\modules\module.py”, 第 545 行,在 调用 结果 = self._slow_forward(*input, **kwargs) 文件“C:\Users\HP\AppData\Local\Programs\Python\Python36\lib\site-packages\torch\nn\modules\module.py”, 第 531 行,在_slow_forward 文件“C:\Users\HP\AppData\Local\Programs\Python\Python36\lib\site-packages\torch\nn\modules\module.py”, 第 531 行,在_slow_forward 结果 = self.forward(*input, **kwargs) 文件“D:\smallobject2\export test s\small_object_detection_master_samirsen\model\faster_rcnn.py”,行 133,前进 h, rois, roi_indices) 文件“C:\Users\HP\AppData\Local\Programs\Python\Python36\lib\site-packages\torch\nn\modules\module.py”, 第 545 行,在 调用 结果 = self._slow_forward(*input, **kwargs) 文件“C:\Users\HP\AppData\Local\Programs\Python\Python36\lib\site-packages\torch\nn\modules\module.py”, 第 531 行,在_slow_forward 结果 = self.forward(*input, **kwargs) 文件“D:\smallobject2\export test s\small_object_detection_master_samirsen\model\faster_rcnn_vgg16.py”, 第 142 行,向前 池 = self.roi(x, indices_and_rois) 文件“C:\Users\HP\AppData\Local\Programs\Python\Python36\lib\site-packages\torch\nn\modules\module.py”, 第 545 行,在 调用 结果 = self._slow_forward(*input, **kwargs) 文件“C:\Users\HP\AppData\Local\Programs\Python\Python36\lib\site-packages\torch\nn\modules\module.py”, 第 531 行,在_slow_forward 结果 = self.forward(*input, **kwargs) 文件“D:\smallobject2\export test s\small_object_detection_master_samirsen\model\roi_module.py”,行 85、前锋 返回 self.RoI(x, rois) RuntimeError: 已尝试跟踪 RoI,但不支持跟踪遗留函数
【问题讨论】:
看起来该模型正在执行 pytorch 导出到 onnx 功能不支持的操作。支持的运算符列表可以在here找到。 代码中包含一些 C 函数以及 python 代码。我想这些都会引起问题。有没有办法让这个出口成功?有没有人尝试过或者有人可以尝试这个模型导出? 【参考方案1】:这是因为 ONNX 不支持 torch.grad.Function。问题是因为 ROI 类 Refer this
要解决这个问题,您必须将前向和后向函数实现为单独的函数定义,而不是 ROI 类的成员。 FasterRCNNVGG16 中对 ROI 的函数调用应该被更改为显式调用前向和后向函数。
【讨论】:
以上是关于Pytorch 到 ONNX 导出功能失败并导致遗留功能错误的主要内容,如果未能解决你的问题,请参考以下文章
PyTorch 1.0 中文官方教程:使用ONNX将模型从PyTorch传输到Caffe2和移动端