使用 HuggingFace 库在 Pytorch 中训练 n% 的最后一层 BERT(训练 12 个中的最后 5 个 BERTLAYER。)

Posted

技术标签:

【中文标题】使用 HuggingFace 库在 Pytorch 中训练 n% 的最后一层 BERT(训练 12 个中的最后 5 个 BERTLAYER。)【英文标题】:Train n% last layers of BERT in Pytorch using HuggingFace Library (train Last 5 BERTLAYER out of 12 .) 【发布时间】:2021-03-09 01:39:44 【问题描述】:

Bert 的架构类似于 encoder -> 12 BertLayer -> Pooling。我想训练 Bert 模型的最后 40% 层。我可以将所有图层冻结为:

# freeze parameters
bert = AutoModel.from_pretrained('bert-base-uncased')
for param in bert.parameters():
    param.requires_grad = False

但我想训练最后 40% 的层。当我执行len(list(bert.parameters())) 时,它给了我 199。所以让我们假设 79 是 40% 的参数。我可以这样做吗:

for param in list(bert.parameters())[-79:]: # total  trainable 199 Params: 79 is 40%
    param.requires_grad = False

我认为它会冻结前 60% 的层。

另外,谁能告诉我它会根据架构冻结哪些层?

【问题讨论】:

【参考方案1】:

您可能正在寻找named_parameters。

for name, param in bert.named_parameters():                                            
    print(name)

输出:

embeddings.word_embeddings.weight
embeddings.position_embeddings.weight
embeddings.token_type_embeddings.weight
embeddings.LayerNorm.weight
embeddings.LayerNorm.bias
encoder.layer.0.attention.self.query.weight
encoder.layer.0.attention.self.query.bias
encoder.layer.0.attention.self.key.weight
...

named_parameters 还会显示您没有冻结前 60% 而是最后 40%:

for name, param in bert.named_parameters():
    if param.requires_grad == True:
        print(name) 

输出:

embeddings.word_embeddings.weight
embeddings.position_embeddings.weight
embeddings.token_type_embeddings.weight
embeddings.LayerNorm.weight
embeddings.LayerNorm.bias
encoder.layer.0.attention.self.query.weight
encoder.layer.0.attention.self.query.bias
encoder.layer.0.attention.self.key.weight
encoder.layer.0.attention.self.key.bias
encoder.layer.0.attention.self.value.weight
...

您可以使用以下方法冻结前 60%:

for name, param in list(bert.named_parameters())[:-79]: 
    print('I will be frozen: '.format(name)) 
    param.requires_grad = False

【讨论】:

以上是关于使用 HuggingFace 库在 Pytorch 中训练 n% 的最后一层 BERT(训练 12 个中的最后 5 个 BERTLAYER。)的主要内容,如果未能解决你的问题,请参考以下文章

使用 huggingface pytorch-transformers GPT-2 进行分类任务

用于命名实体识别的 PyTorch Huggingface BERT-NLP

Huggingface 在情绪分析任务中给出 pytorch 索引错误

pytorch+huggingface实现基于bert模型的文本分类(附代码)

将 AllenNLP 解释与 HuggingFace 模型一起使用

huggingface/transformers快速上手