How to get onnx format from pretrained GPT2 models?
Posted: 2021-08-29 16:48:21

I am trying to convert KoGPT2, a model pretrained with GPT2, into ONNX format so that I can then convert the model to TensorFlow format. I used convert_graph_to_onnx from transformers, but for some reason it does not work. I don't know what this error means. Can this model be exported to ONNX at all? Here is the code I ran, followed by the error.

Thank you.
import sys
!{sys.executable} -m pip install --upgrade git+https://github.com/huggingface/transformers
!{sys.executable} -m pip install --upgrade torch==1.6.0+cpu torchvision==0.7.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
!{sys.executable} -m pip install --upgrade onnxruntime==1.4.0
!{sys.executable} -m pip install -i https://test.pypi.org/simple/ ort-nightly
!{sys.executable} -m pip install --upgrade onnxruntime-tools
!rm -rf onnx/
from pathlib import Path
from transformers.convert_graph_to_onnx import convert
# Handles all the above steps for you
convert(framework="pt", model="skt/kogpt2-base-v2", output=Path('/content/drive/MyDrive/kogptonnx/kogpt.onnx'), opset=12)
# Tensorflow
# convert(framework="tf", model="bert-base-cased", output="onnx/bert-base-cased.onnx", opset=11)
ONNX opset version set to: 11
Loading pipeline (model: skt/kogpt2-base-v2, tokenizer: skt/kogpt2-base-v2)
Some weights of the model checkpoint at skt/kogpt2-base-v2 were not used when initializing GPT2Model: ['lm_head.weight']
- This IS expected if you are initializing GPT2Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing GPT2Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Using framework PyTorch: 1.6.0+cpu
Found input input_ids with shape: 0: 'batch', 1: 'sequence'
Found input attention_mask with shape: 0: 'batch', 1: 'sequence'
Found output output_0 with shape: 0: 'batch', 1: 'sequence'
Found output output_1 with shape: 0: 'batch', 1: 'sequence', 2: 'sequence'
Found output output_1 with shape: 0: 'batch', 1: 'sequence', 2: 'sequence'
Found output output_2 with shape: 0: 'batch', 1: 'sequence', 2: 'sequence'
Found output output_2 with shape: 0: 'batch', 1: 'sequence', 2: 'sequence'
Found output output_3 with shape: 0: 'batch', 1: 'sequence', 2: 'sequence'
Found output output_3 with shape: 0: 'batch', 1: 'sequence', 2: 'sequence'
Found output output_4 with shape: 0: 'batch', 1: 'sequence', 2: 'sequence'
Found output output_4 with shape: 0: 'batch', 1: 'sequence', 2: 'sequence'
Found output output_5 with shape: 0: 'batch', 1: 'sequence', 2: 'sequence'
Found output output_5 with shape: 0: 'batch', 1: 'sequence', 2: 'sequence'
Found output output_6 with shape: 0: 'batch', 1: 'sequence', 2: 'sequence'
Found output output_6 with shape: 0: 'batch', 1: 'sequence', 2: 'sequence'
Found output output_7 with shape: 0: 'batch', 1: 'sequence', 2: 'sequence'
Found output output_7 with shape: 0: 'batch', 1: 'sequence', 2: 'sequence'
Found output output_8 with shape: 0: 'batch', 1: 'sequence', 2: 'sequence'
Found output output_8 with shape: 0: 'batch', 1: 'sequence', 2: 'sequence'
Found output output_9 with shape: 0: 'batch', 1: 'sequence', 2: 'sequence'
Found output output_9 with shape: 0: 'batch', 1: 'sequence', 2: 'sequence'
Found output output_10 with shape: 0: 'batch', 1: 'sequence', 2: 'sequence'
Found output output_10 with shape: 0: 'batch', 1: 'sequence', 2: 'sequence'
Found output output_11 with shape: 0: 'batch', 1: 'sequence', 2: 'sequence'
Found output output_11 with shape: 0: 'batch', 1: 'sequence', 2: 'sequence'
Found output output_12 with shape: 0: 'batch', 1: 'sequence', 2: 'sequence'
Found output output_12 with shape: 0: 'batch', 1: 'sequence', 2: 'sequence'
Ensuring inputs are in correct order
past_key_values is not present in the generated input list.
Generated inputs order: ['input_ids']
/usr/local/lib/python3.7/dist-packages/transformers/models/gpt2/modeling_gpt2.py:181: TracerWarning: Converting a tensor to a Python float might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
attn_weights = attn_weights / (float(value.size(-1)) ** 0.5)
/usr/local/lib/python3.7/dist-packages/transformers/models/gpt2/modeling_gpt2.py:186: TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-6-78cc7242cbdd> in <module>()
4
5 # Handles all the above steps for you
----> 6 convert(framework="pt", model="skt/kogpt2-base-v2", output=Path('/content/drive/MyDrive/kogptonnx/kogpt.onnx'), opset=11)
7
8 # Tensorflow
6 frames
/usr/local/lib/python3.7/dist-packages/torch/onnx/utils.py in _optimize_graph(graph, operator_export_type, _disable_torch_constant_prop, fixed_batch_size, params_dict)
187
188 graph = torch._C._jit_pass_onnx(graph, operator_export_type)
--> 189 torch._C._jit_pass_lint(graph)
190
191 torch._C._jit_pass_onnx_scalar_type_analysis(graph)
RuntimeError: Unable to cast from non-held to held instance (T& to Holder<T>) (compile in debug mode for type information)
Answer 1:

Not sure whether it helps here, but for a different model (also transformers) I got the same Unable to cast from non-held to held instance error message. In my case, adding the operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK option to torch.onnx.export(...) (as mentioned here) fixed it for me:

torch.onnx.export(model, input, "output-name.onnx", export_params=True, opset_version=12, operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK)
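Applied to the KoGPT2 model from the question, a minimal sketch of that workaround might look like the following. The dummy prompt, output path, and the return_dict/use_cache tweaks are illustrative assumptions, not part of the original answer; disabling the cache keeps the exported graph to a single logits output instead of the many past key/value outputs seen in the log above.

import torch
from transformers import GPT2LMHeadModel, PreTrainedTokenizerFast

tokenizer = PreTrainedTokenizerFast.from_pretrained("skt/kogpt2-base-v2")
model = GPT2LMHeadModel.from_pretrained("skt/kogpt2-base-v2")
model.eval()
model.config.return_dict = False  # return plain tuples, which trace cleanly
model.config.use_cache = False    # drop past key/value outputs for a simpler graph

# A dummy input used only to trace the graph during export
dummy_input = tokenizer("안녕하세요", return_tensors="pt")["input_ids"]

torch.onnx.export(
    model,
    (dummy_input,),
    "kogpt2.onnx",  # output path (illustrative)
    export_params=True,
    opset_version=12,
    input_names=["input_ids"],
    output_names=["logits"],
    dynamic_axes={
        "input_ids": {0: "batch", 1: "sequence"},
        "logits": {0: "batch", 1: "sequence"},
    },
    operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
)

If the export succeeds, the resulting kogpt2.onnx can then be handed to an ONNX-to-TensorFlow converter, which is the end goal stated in the question.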