如何加载部分预训练的 pytorch 模型?

Posted

技术标签:

【中文标题】如何加载部分预训练的 pytorch 模型?【英文标题】:How can I load a partial pretrained pytorch model? 【发布时间】:2020-07-27 10:00:01 【问题描述】:

我正在尝试让 pytorch 模型在句子分类任务上运行。当我处理医学笔记时,我正在使用 ClinicalBert (https://github.com/kexinhuang12345/clinicalBERT) 并希望使用其预训练的权重。不幸的是,ClinicalBert 模型仅将文本分类为 1 个二进制标签,而我有 281 个二进制标签。因此,我试图实现此代码https://github.com/kaushaltrivedi/bert-toxic-comments-multilabel/blob/master/toxic-bert-multilabel-classification.ipynb,其中 bert 之后的最终分类器为 281 长。

如何在不加载分类权重的情况下从 ClinicalBert 模型加载预训练的 Bert 权重?

天真地尝试从预训练的 ClinicalBert 权重中加载权重,我收到以下错误:

size mismatch for classifier.weight: copying a param with shape torch.Size([2, 768]) from checkpoint, the shape in current model is torch.Size([281, 768]).
size mismatch for classifier.bias: copying a param with shape torch.Size([2]) from checkpoint, the shape in current model is torch.Size([281]).

我目前尝试替换 pytorch_pretrained_bert 包中的 from_pretrained 函数,并像这样弹出分类器权重和偏差:

def from_pretrained(cls, pretrained_model_name, state_dict=None, cache_dir=None, *inputs, **kwargs):
    ...
    if state_dict is None:
        weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
        state_dict = torch.load(weights_path, map_location='cpu')
    state_dict.pop('classifier.weight')
    state_dict.pop('classifier.bias')
    old_keys = []
    new_keys = []
    ...

我收到以下错误消息: 信息 - 建模诊断 - BertForMultiLabelSequenceClassification 的权重未从预训练模型初始化:['classifier.weight', 'classifier.bias']

最后我想从clinicalBert预训练权重加载bert嵌入,并随机初始化***分类器权重。

【问题讨论】:

【参考方案1】:

在加载之前删除状态字典中的键是一个好的开始。假设您使用nn.Module.load_state_dict 加载预训练的权重,那么您还需要设置strict=False 参数以避免意外或丢失键导致的错误。这将忽略 state_dict 中不存在于模型中的条目(意外键),并且对您而言更重要的是,会将缺失的条目保留为默认初始化(缺失键)。为安全起见,您可以检查方法的返回值,以验证相关权重是否属于缺失键的一部分,并且没有任何意外键。

【讨论】:

以上是关于如何加载部分预训练的 pytorch 模型?的主要内容,如果未能解决你的问题,请参考以下文章

预训练模型的加载机理pytorch版

pytorch中修改后的模型如何加载预训练模型

加载Pytorch中的预训练模型及部分结构的导入

加载Pytorch中的预训练模型及部分结构的导入

pytorch加载内置模型、修改网络结构及加载预训练参数

PyTorch学习笔记 5.torchvision库