使用 huggingface pytorch-transformers GPT-2 进行分类任务

Posted

技术标签:

【中文标题】使用 huggingface pytorch-transformers GPT-2 进行分类任务【英文标题】:using huggingface's pytorch- transformers GPT-2 for classifcation tasks 【发布时间】:2019-12-06 10:55:08 【问题描述】:

我想使用 GPT-2 来制作文本分类器模型。通过 GPT-2 提取特征后,我不确定应该添加什么头。例如,我有一个序列。

import pytorch_transformers as pt 
import torch
text=test.iloc[1,1]
text
'If a fire wanted fanning, it could readily be fanned with a newspaper, and as the government grew weaker, I have no doubt that leather and iron acquired durability in proportion, for, in a very short time, there was not a pair of bellows in all Rotterdam that ever stood in need of a stitch or required the assistance of a hammer.'
len(text)

74
tokenizer = pt.GPT2Tokenizer.from_pretrained('gpt2')
model = pt.GPT2Model.from_pretrained('gpt2')
zz = tokenizer.tokenize(text)
z1=torch.tensor([tokenizer.convert_tokens_to_ids(zz)])
z1
tensor([[ 1532,   257,  2046,  2227,  4336,   768,    11,   340,   714, 14704,
           307,   277,  3577,   351,   257,  7533,    11,   290,   355,   262,
          1230,  6348, 17642,    11,   314,   423,   645,  4719,   326, 11620,
           290,  6953,  9477, 26578,   287,  9823,    11,   329,    11,   287,
           257,   845,  1790,   640,    11,   612,   373,   407,   257,  5166,
           286,  8966,  1666,   287,   477, 18481,   353, 11043,   326,  1683,
          6204,   287,   761,   286,   257, 24695,   393,  2672,   262,  6829,
           286,   257, 15554,    13]])
output,hidden=model(z1)
ouput.shape
torch.Size([1, 74, 768])

GPT2 的输出对我来说是 nxmx 768,其中 n 是批量大小,m 是序列中的标记数(例如,我可以填充/截断为 128。),所以我不能做论文说对于分类任务只需在尾部添加一个全连接层。我在谷歌上搜索,很少提到 GPT-2 分类任务。 我不确定什么是正确的。我应该在全连接层之前做flatten/max pooling/average pooling还是别的什么?

【问题讨论】:

我试过平均池,结果还不错。验证 logloss 比 BERT 模型稍小。但我不确定做对了 你有没有发现这个问题?你还在做平均池化吗? 【参考方案1】:

“所以我不能像论文所说的那样做分类任务,只需在尾部添加一个全连接层。” - 这是您问题的答案

通常,像 BERT 和 Roberta 这样的转换器具有双向自注意力,并且它们具有 [CLS] 令牌,我们将其输入到分类器中。由于 GPT-2 是左右排列的,因此您需要输入嵌入序列的最终标记。

P.S - 你能把论文的链接放上去吗?

【讨论】:

【参考方案2】:

如果您使用 GPT-2 构建模型进行文本分类,请分享。

【讨论】:

欢迎来到 SO!请不要发布 cmets 作为答案。您可以在 cmets 部分写下您的问题。

以上是关于使用 huggingface pytorch-transformers GPT-2 进行分类任务的主要内容,如果未能解决你的问题,请参考以下文章

使用 huggingface 库会报错:KeyError: 'logits'

将 AllenNLP 解释与 HuggingFace 模型一起使用

使用 HuggingFace 微调 ALBERT 问答

通过 Huggingface 转换器更新 BERT 模型

使用 Huggingface TFTrainer 类微调模型时如何指定损失函数?

如何在 Huggingface 中从 CSV 加载自定义数据集