EncoderDecoderModel 转换解码器的分类器层
Posted
技术标签:
【中文标题】EncoderDecoderModel 转换解码器的分类器层【英文标题】:EncoderDecoderModel converts classifier layer of decoder 【发布时间】:2021-12-10 23:53:04 【问题描述】:我正在尝试使用序列到序列模型进行命名实体识别。我的输出是简单的 IOB 标记,因此我只想预测每个标记 (IOB) 的 3 个标签的概率。
我正在尝试使用 HuggingFace 实现的 EncoderDecoderModel,其中 DistilBert 作为我的编码器,BertForTokenClassification 作为我的解码器。
首先,我导入我的编码器和解码器:
encoder = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")
encoder.save_pretrained("Encoder")
decoder = BertForTokenClassification.from_pretrained('bert-base-uncased',
num_labels=3,
output_hidden_states=False,
output_attentions=False)
decoder.save_pretrained("Decoder")
decoder
当我检查我的解码器模型时,我可以清楚地看到 out_features=3 的线性分类层:
## sample of output:
)
(dropout): Dropout(p=0.1, inplace=False)
(classifier): Linear(in_features=768, out_features=3, bias=True)
)
但是,当我在我的 EncoderDecoderModel 中组合这两个模型时,似乎解码器被转换为不同类型的分类器 - 现在将 out_features 作为我词汇量的大小:
bert2bert = EncoderDecoderModel.from_encoder_decoder_pretrained("./Encoder","./Decoder")
bert2bert
## sample of output:
(cls): BertOnlyMLMHead(
(predictions): BertLMPredictionHead(
(transform): BertPredictionHeadTransform(
(dense): Linear(in_features=768, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
)
(decoder): Linear(in_features=768, out_features=30522, bias=True)
)
)
这是为什么呢?我怎样才能在我的模型中保留 out_features = 3?
【问题讨论】:
【参考方案1】:Huggingface 对其模型使用不同的头部(取决于网络和任务)。虽然这些模型的一部分是相同的(例如上下文编码器模块),但它们在最后一层即头部本身有所不同。
例如,对于分类问题,他们使用XForSequenceClassification
头,其中X
是Bert、Bart 等语言模型的名称。
话虽如此,EncoderDecoderModel
模型使用语言建模头,而您已经存储的解码器使用分类头。当EncoderDecoderModel
看到这些差异时,它使用自己的LMhead
,这是一个线性层,in_features 为 768,映射到 30522 作为词汇表的数量。
为了规避这个问题,您可以使用 vanilla BERTModel 类来输出隐藏表示,然后添加一个线性层进行分类,该层接受与形状为 768 的 BERT 的 [CLS]
标记相关联的嵌入,然后将其通过线性层映射到 3 的输出向量,即标签的数量。
【讨论】:
感谢您的回答!你将如何实现它?我可以在bert2bert
类中添加它吗?
就目前而言,EncoderDecoderModel
的使用在这里已被取消,因为它是为语言建模任务而设计的,而不是序列分类。我的建议是使用BertModel
类为输入生成原始隐藏表示,然后在BERT 的嵌入之上手动添加nn.Linear
层以进行分类。以上是关于EncoderDecoderModel 转换解码器的分类器层的主要内容,如果未能解决你的问题,请参考以下文章
Xamarin 安卓。将字节数组转换为位图。 Skia 解码器返回 false