预期的输入batch_size以匹配目标batch_size(11)
Posted
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了预期的输入batch_size以匹配目标batch_size(11)相关的知识,希望对你有一定的参考价值。
我知道这似乎是一个普遍的问题,但是我找不到解决方案。我正在运行一个多标签分类模型,并且张量大小有问题。
我的完整代码如下:
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification
import torch
# Instantiating tokenizer and model
tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-cased')
model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-cased')
# Instantiating quantized model
quantized_model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
# Forming data tensors
input_ids = torch.tensor(tokenizer.encode(x_train[0], add_special_tokens=True)).unsqueeze(0)
labels = torch.tensor(Y[0]).unsqueeze(0)
# Train model
outputs = quantized_model(input_ids, labels=labels)
loss, logits = outputs[:2]
哪个会产生错误:
ValueError: Expected input batch_size (1) to match target batch_size (11)
Input_ids看起来像:
tensor([[ 101, 789, 160, 1766, 1616, 1110, 170, 1205, 7727, 1113,
170, 2463, 1128, 1336, 1309, 1138, 112, 119, 11882, 11545,
119, 108, 15710, 108, 3645, 108, 3994, 102]])
具有形状:
torch.Size([1, 28])
并且标签看起来像:
tensor([[0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1]])
具有形状:
torch.Size([1, 11])
input_ids的大小将随着要编码的字符串的大小而变化。
[我还注意到,当输入5的Y值以产生5个标签时,会产生错误:
ValueError: Expected input batch_size (1) to match target batch_size (55).
带有标签形状:
torch.Size([1, 5, 11])
((请注意,我没有输入5个input_id,这大概就是为什么输入大小保持不变的原因)
我已经尝试了几种不同的方法来使它们起作用,但是我现在很茫然。我真的很感谢一些指导。谢谢!
DistilBertForSequenceClassification
的标签必须具有文档中提到的大小DistilBertForSequenceClassification
:
- 标签(形状为
torch.Size([batch_size])
的torch.LongTensor
,可选,默认为(batch_size,)
)–用于计算序列分类/回归损失的标签。索引应在None
中。如果[0, ..., config.num_labels - 1]
,则计算回归损失(均方根损失);如果config.num_labels == 1
,则计算分类损失(交叉熵)。在您的情况下,您的
config.num_labels > 1
的大小应为labels
。
这对您的数据来说是不可能的,这是因为序列分类为每个序列都有一个标签,但是您希望将其设为多标签分类。
据我所知,HuggingFace的转换器库中没有多标签模型,您可以直接使用。您将需要创建自己的模型,这并不是特别困难,因为这些额外的模型都使用相同的基本模型,并根据要解决的任务最后添加适当的分类器。 torch.Size([1])
解释了如何做到这一点。
以上是关于预期的输入batch_size以匹配目标batch_size(11)的主要内容,如果未能解决你的问题,请参考以下文章
ValueError:预期输入 batch_size (59) 与目标 batch_size (1) 匹配