HuggingFace 文本摘要输入数据格式问题

Posted

技术标签:

【中文标题】HuggingFace 文本摘要输入数据格式问题【英文标题】:HuggingFace text summarization input data format issue 【发布时间】:2021-10-12 15:33:29 【问题描述】:

我正在尝试微调模型以执行文本摘要。我使用的是AutoModelForSeq2SeqLM.from_pretrained(),因此以下适用于多个模型(例如 T5、ProphetNet、BART)。

我创建了一个名为CustomDataset 的类,它是torch.utils.Dataset 的子类。该类包含一个字段:samples - 具有 encodingslabels 键的字典列表。每个字典中的每个值都是torch.Tensorsamples 中的条目如下所示:

'encoding': tensor([[21603, 10, 188, 563, 1]]), 'label': tensor([[ 1919, 22003, 22, 7, 1]])

这是我尝试使用 Trainer 微调模型的方法:

model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

training_args = TrainingArguments("test_trainer")
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=data,
)
trainer.train()

transformers\data\data_collator.py 的第 63 行抛出了我遇到的错误。这是那行代码:

label = first["label"].item() if isinstance(first["label"], torch.Tensor) else first["label"]

这是错误消息: ValueError: only one element tensors can be converted to Python scalars

我理解为什么会特别抛出错误消息 - first["label"] 张量不是单元素张量,因此不能在其上调用 item()。不过,这不是我问这个问题的原因。

我假设我没有正确传递数据,但在我看来 Trainer 应该自己处理 input_idsdecoder_input_ids。我尝试手动设置这些(将encodings 传递为input_ids,将labels 传递为decoder_input_ids)并且模型可以成功执行推理,但我没有设法对其进行微调。我在哪里犯了错误,我该如何解决?

【问题讨论】:

【参考方案1】:

使用名称 label_ids 而不是 label 可以解决特定问题。如果标签是intfloat 或单元素torch.Tensor,则应使用label。对于具有多个元素的张量,请使用 label_ids。详情请参阅data_collator.py,第 62-71 行:

if "label" in first and first["label"] is not None:
    label = first["label"].item() if isinstance(first["label"], torch.Tensor) else first["label"]
    dtype = torch.long if isinstance(label, int) else torch.float
    batch["labels"] = torch.tensor([f["label"] for f in features], dtype=dtype)
elif "label_ids" in first and first["label_ids"] is not None:
    if isinstance(first["label_ids"], torch.Tensor):
        batch["labels"] = torch.stack([f["label_ids"] for f in features])
    else:
        dtype = torch.long if type(first["label_ids"][0]) is int else torch.float
        batch["labels"] = torch.tensor([f["label_ids"] for f in features], dtype=dtype)

此外,应使用名称input_ids 而不是encoding。否则,将引发 unknown kwarg 错误。

【讨论】:

以上是关于HuggingFace 文本摘要输入数据格式问题的主要内容,如果未能解决你的问题,请参考以下文章

Transformers学习笔记3. HuggingFace管道函数Pipeline

Fine Tuning Huggingface RobertaForQuestionAnswering 的输入/输出格式

输入文件应该如何格式化以进行语言模型微调(BERT 通过 Huggingface Transformers)?

使用 huggingface 的 distilbert 模型生成文本

阿尔伯特没有收敛 - HuggingFace

如何微调 HuggingFace BERT 模型以进行文本分类 [关闭]