HuggingFace 文本摘要输入数据格式问题
Posted
技术标签:
【中文标题】HuggingFace 文本摘要输入数据格式问题【英文标题】:HuggingFace text summarization input data format issue 【发布时间】:2021-10-12 15:33:29 【问题描述】:我正在尝试微调模型以执行文本摘要。我使用的是AutoModelForSeq2SeqLM.from_pretrained()
,因此以下适用于多个模型(例如 T5、ProphetNet、BART)。
我创建了一个名为CustomDataset
的类,它是torch.utils.Dataset
的子类。该类包含一个字段:samples
- 具有 encodings
和 labels
键的字典列表。每个字典中的每个值都是torch.Tensor
。 samples
中的条目如下所示:
'encoding': tensor([[21603, 10, 188, 563, 1]]), 'label': tensor([[ 1919, 22003, 22, 7, 1]])
这是我尝试使用 Trainer 微调模型的方法:
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
training_args = TrainingArguments("test_trainer")
trainer = Trainer(
model=model,
args=training_args,
train_dataset=data,
)
trainer.train()
transformers\data\data_collator.py
的第 63 行抛出了我遇到的错误。这是那行代码:
label = first["label"].item() if isinstance(first["label"], torch.Tensor) else first["label"]
这是错误消息:
ValueError: only one element tensors can be converted to Python scalars
我理解为什么会特别抛出错误消息 - first["label"]
张量不是单元素张量,因此不能在其上调用 item()
。不过,这不是我问这个问题的原因。
我假设我没有正确传递数据,但在我看来 Trainer
应该自己处理 input_ids
和 decoder_input_ids
。我尝试手动设置这些(将encodings
传递为input_ids
,将labels
传递为decoder_input_ids
)并且模型可以成功执行推理,但我没有设法对其进行微调。我在哪里犯了错误,我该如何解决?
【问题讨论】:
【参考方案1】:使用名称 label_ids
而不是 label
可以解决特定问题。如果标签是int
、float
或单元素torch.Tensor
,则应使用label
。对于具有多个元素的张量,请使用 label_ids
。详情请参阅data_collator.py
,第 62-71 行:
if "label" in first and first["label"] is not None:
label = first["label"].item() if isinstance(first["label"], torch.Tensor) else first["label"]
dtype = torch.long if isinstance(label, int) else torch.float
batch["labels"] = torch.tensor([f["label"] for f in features], dtype=dtype)
elif "label_ids" in first and first["label_ids"] is not None:
if isinstance(first["label_ids"], torch.Tensor):
batch["labels"] = torch.stack([f["label_ids"] for f in features])
else:
dtype = torch.long if type(first["label_ids"][0]) is int else torch.float
batch["labels"] = torch.tensor([f["label_ids"] for f in features], dtype=dtype)
此外,应使用名称input_ids
而不是encoding
。否则,将引发 unknown kwarg
错误。
【讨论】:
以上是关于HuggingFace 文本摘要输入数据格式问题的主要内容,如果未能解决你的问题,请参考以下文章
Transformers学习笔记3. HuggingFace管道函数Pipeline
Fine Tuning Huggingface RobertaForQuestionAnswering 的输入/输出格式
输入文件应该如何格式化以进行语言模型微调(BERT 通过 Huggingface Transformers)?