加载和冻结预训练模型以与新网络结合

Posted

技术标签:

【中文标题】加载和冻结预训练模型以与新网络结合【英文标题】:Loading & Freezing a Pretrained Model to Combine with a New Network 【发布时间】:2020-06-21 00:49:25 【问题描述】:

我有一个预训练模型,并希望在其之上构建一个分类器。我正在尝试加载和冻结预训练模型的权重,并将其输出传递给我想要优化的新分类器。这是我到目前为止所拥有的,我有点卡在nn.Sequential 行中的TypeError: forward() missing 1 required positional argument: 'x' 错误:

import model #model.py contains the architecture of the pretrained model

class Classifier(nn.Module):
    def __init__(self):
        ...
    def forward(self, x):
        ...

net = model.Model()
net.load_state_dict(checkpoint["net"])

for c in net.children():
    for param in child.parameters():
        params.requires_grad = False

model = nn.Sequential(nn.ModuleList(net()), Classifier())

【问题讨论】:

【参考方案1】:

TL;DR

model = nn.Sequential(nn.ModuleList(net), Classifier())

您正在通过net()“调用”net.forward,而不是Classifier()Classifier__init__ 方法

【讨论】:

【参考方案2】:

在与 PyTorch 论坛的 @ptrblck 讨论后,我终于解决了这个问题。该解决方案类似于 Shai 的答案,只是因为net 包含model.Model 类的一个实例,所以应该改为使用model = nn.Sequential(net, Classifier()),而不是调用nn.ModuleList()

【讨论】:

以上是关于加载和冻结预训练模型以与新网络结合的主要内容,如果未能解决你的问题,请参考以下文章

由于内存问题,如何保存与预训练的 bert 模型的分类器层相关的参数?

Keras-在预训练好网络模型上进行fine-tune

pytorch加载内置模型、修改网络结构及加载预训练参数

pytorch如何给预训练模型添加新的层

代码补全快餐教程 - 预训练模型的加载和使用

pytorch中修改后的模型如何加载预训练模型