如何在使用 ONNX 推理会话时通过传递“标签”来获得语言建模损失?

Posted

技术标签:

【中文标题】如何在使用 ONNX 推理会话时通过传递“标签”来获得语言建模损失?【英文标题】:How to get the language modeling loss by passing 'labels' while using ONNX inference session? 【发布时间】:2021-10-14 04:26:37 【问题描述】:

当使用 GPT2 时,我们可以简单地传递 'labels' 参数来获取损失,如下所示:

import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2', return_dict=True)

inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
outputs = model(**inputs, labels=inputs["input_ids"])
loss = outputs.loss

但是,无法找出如何在 ONNX 推理会话中获得相同的损失。我正在使用下面的代码,它只返回“last_hidden_​​state”:

import onnxruntime as ort

from transformers import GPT2TokenizerFast
#tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

ort_session = ort.InferenceSession("onnx/gpt2/model.onnx")

inputs = tokenizer("Using BERT in ONNX!", return_tensors="np")
outputs = ort_session.run(["last_hidden_state"], dict(inputs))

【问题讨论】:

【参考方案1】:

“onnx/gpt2/model.onnx”是如何生成的?

看起来 PyTorch 运行使用 transformers.GPT2LMHeadModel,而 ORT 运行使用 transformers.GPT2Model,这是一个“裸 GPT2 模型转换器,输出原始隐藏状态,没有任何特定的头部”并且不返回损耗。

【讨论】:

您好,我使用 Gpt2Helper.py 中的 MyGPT2LMHeadModel 生成 'model.onnx' 并使用 export_onnx 方法保存。 ``` from onnxruntime.transformers.gpt2_helper import Gpt2Helper, MyGPT2LMHeadModel from transformers import AutoConfig model_name_or_path = "gpt2" config = AutoConfig.from_pretrained(model_name_or_path, cache_dir=cache_dir) model = MyGPT2LMHeadModel.from_pretrained(model_name_or_path, config=config, cache_dir=cache_dir) device = torch.device("cpu") model.eval().to(device) onnx_model_path = "onnx/gpt2/gpt2.onnx" Gpt2Helper.export_onnx(model, device, onnx_model_path) ``` 我能够对 GPTLMHeadModel 的 forward() 方法进行更改以发送“labels = input_ids”,它返回“loss”作为第一个输出,它帮助解决了我的问题。

以上是关于如何在使用 ONNX 推理会话时通过传递“标签”来获得语言建模损失?的主要内容,如果未能解决你的问题,请参考以下文章

模型推理加速系列BERT加速方案对比 TorchScript vs. ONNX

模型推理加速系列04:BERT加速方案对比 TorchScript vs. ONNX

Pytorch的pth模型转onnx,再用ONNX Runtime调用推理(附python代码)

对象检测模型 (PyTorch) 到 ONNX:ONNX 推理的空输出

NLP涉及技术原理和应用简单讲解:paddle(梯度裁剪ONNX协议动态图转静态图推理部署)

NLP涉及技术原理和应用简单讲解:paddle(梯度裁剪ONNX协议动态图转静态图推理部署)