Torch JIT Trace = TracerWarning:将张量转换为 Python 布尔值可能会导致跟踪不正确
Posted
技术标签:
【中文标题】Torch JIT Trace = TracerWarning:将张量转换为 Python 布尔值可能会导致跟踪不正确【英文标题】:Torch JIT Trace = TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect 【发布时间】:2021-06-19 02:42:54 【问题描述】:我正在关注本教程:https://huggingface.co/transformers/torchscript.html
创建我的自定义 BERT 模型的跟踪,但是当运行完全相同的 dummy_input
时,我收到一个错误:
TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect.
We cant record the data flow of Python values, so this value will be treated as a constant in the future.
在我的模型和标记器中加载后,创建跟踪的代码如下:
text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
tokenized_text = tokenizer.tokenize(text)
# Masking one of the input tokens
masked_index = 8
tokenized_text[masked_index] = '[MASK]'
indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]
tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])
dummy_input = [tokens_tensor, segments_tensors]
traced_model = torch.jit.trace(model, dummy_input)
dummy_input
是张量列表,所以我不确定Boolean
类型在哪里发挥作用。有谁知道为什么会发生此错误以及是否发生布尔转换?
非常感谢
【问题讨论】:
【参考方案1】:这个错误是什么意思
当尝试torch.jit.trace
具有数据相关控制流的模型时,会出现此警告。
这个简单的例子应该更清楚:
import torch
class Foo(torch.nn.Module):
def forward(self, tensor):
# It is data dependent
# Trace will only work with one path
if tensor.max() > 0.5:
return tensor ** 2
return tensor
model = Foo()
traced = torch.jit.script(model) # No warnings
traced = torch.jit.trace(model, torch.randn(10)) # Warning
本质上,BERT 模型有一些依赖于数据的控制流(如if
、for
循环),因此您会收到警告。
警告本身
你可以看到BERTforward
代码here。
如果:
参数不会改变(例如传递给forward
的None
值)并且在script
之后将保持不变(例如在推理调用期间)
如果存在基于__init__
内部收集的数据的控制流(如配置),因为这不会改变
例如:
elif input_ids is not None:
input_shape = input_ids.size()
batch_size, seq_length = input_shape
只会与torch.jit.trace
作为一个分支运行,因为它只是跟踪张量上的操作并且不知道这样的控制流。
HuggingFace 团队可能已经意识到这一点,并且此警告不是问题(尽管您可能会仔细检查您的用例或尝试使用 torch.jit.script
)
使用torch.jit.script
这很难,因为整个模型必须与 torchscript
兼容(torchscript
有一个可用的 Python 子集,而且很可能无法与 BERT 一起使用)。
仅在必要时才这样做(可能不会)。
【讨论】:
以上是关于Torch JIT Trace = TracerWarning:将张量转换为 Python 布尔值可能会导致跟踪不正确的主要内容,如果未能解决你的问题,请参考以下文章
如何正确地将 cv::Mat 转换为具有完美匹配值的 torch::Tensor?
pytorch 可以优化顺序操作(如张量流图或 JAX 的 jit)吗?