当想要找到具有最高 `start` 分数的标记时,torch.argmax() 中的 TypeError

Posted

技术标签:

【中文标题】当想要找到具有最高 `start` 分数的标记时,torch.argmax() 中的 TypeError【英文标题】:TypeError in torch.argmax() when want to find the tokens with the highest `start` score 【发布时间】:2021-11-13 07:20:25 【问题描述】:

我想运行此代码以使用拥抱面部转换器回答问题。

import torch
from transformers import BertForQuestionAnswering
from transformers import BertTokenizer

#Model
model = BertForQuestionAnswering.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')

#Tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')

question = '''Why was the student group called "the Methodists?"'''

paragraph = ''' The movement which would become The United Methodist Church began in the mid-18th century within the Church of England.
            A small group of students, including John Wesley, Charles Wesley and George Whitefield, met on the Oxford University campus.
            They focused on Bible study, methodical study of scripture and living a holy life.
            Other students mocked them, saying they were the "Holy Club" and "the Methodists", being methodical and exceptionally detailed in their Bible study, opinions and disciplined lifestyle.
            Eventually, the so-called Methodists started individual societies or classes for members of the Church of England who wanted to live a more religious life. '''
            
encoding = tokenizer.encode_plus(text=question,text_pair=paragraph)

inputs = encoding['input_ids']  #Token embeddings
sentence_embedding = encoding['token_type_ids']  #Segment embeddings
tokens = tokenizer.convert_ids_to_tokens(inputs) #input tokens

start_scores, end_scores = model(input_ids=torch.tensor([inputs]), token_type_ids=torch.tensor([sentence_embedding]))

start_index = torch.argmax(start_scores)

但我在最后一行收到此错误:

Exception has occurred: TypeError
argmax(): argument 'input' (position 1) must be Tensor, not str
  File "D:\bert\QuestionAnswering.py", line 33, in <module>
    start_index = torch.argmax(start_scores)

我不知道怎么了。谁能帮帮我?

【问题讨论】:

【参考方案1】:

BertForQuestionAnswering 返回一个QuestionAnsweringModelOutput 对象。

由于您将BertForQuestionAnswering 的输出设置为start_scores, end_scores,因此返回的QuestionAnsweringModelOutput 对象被强制转换为字符串元组('start_logits', 'end_logits'),从而导致类型不匹配错误。

以下应该有效:

outputs = model(input_ids=torch.tensor([inputs]), token_type_ids=torch.tensor([sentence_embedding]))

start_index = torch.argmax(outputs.start_logits)

【讨论】:

【参考方案2】:

Huggingface 转换器提供了一种运行模型的简单高级方法,如guide 所示:

from transformers import pipeline

nlp = pipeline('question-answering', model=model, tokenizer=tokenizer)
print(nlp(question=question, context=paragraph, topk=5))

topk 允许选择几个得分最高的答案。

【讨论】:

以上是关于当想要找到具有最高 `start` 分数的标记时,torch.argmax() 中的 TypeError的主要内容,如果未能解决你的问题,请参考以下文章

当他们的分数相等时,如何将排名分配给他们共享最高排名的学生?

如何使用 Linq 查询查找分数范围(没有最高分数)

评估 R^2 分数时出错

如何选择具有父子关系的记录对父级具有最高分数

Python:从txt读取分数并显示最高分数和记录

具有多列的数组中最接近的最高值