加载和冻结预训练模型以与新网络结合
Posted
技术标签:
【中文标题】加载和冻结预训练模型以与新网络结合【英文标题】:Loading & Freezing a Pretrained Model to Combine with a New Network 【发布时间】:2020-06-21 00:49:25 【问题描述】:我有一个预训练模型,并希望在其之上构建一个分类器。我正在尝试加载和冻结预训练模型的权重,并将其输出传递给我想要优化的新分类器。这是我到目前为止所拥有的,我有点卡在nn.Sequential
行中的TypeError: forward() missing 1 required positional argument: 'x'
错误:
import model #model.py contains the architecture of the pretrained model
class Classifier(nn.Module):
def __init__(self):
...
def forward(self, x):
...
net = model.Model()
net.load_state_dict(checkpoint["net"])
for c in net.children():
for param in child.parameters():
params.requires_grad = False
model = nn.Sequential(nn.ModuleList(net()), Classifier())
【问题讨论】:
【参考方案1】:TL;DR
model = nn.Sequential(nn.ModuleList(net), Classifier())
您正在通过net()
“调用”net.forward
,而不是Classifier()
中Classifier
的__init__
方法类。
【讨论】:
【参考方案2】:在与 PyTorch 论坛的 @ptrblck 讨论后,我终于解决了这个问题。该解决方案类似于 Shai 的答案,只是因为net
包含model.Model
类的一个实例,所以应该改为使用model = nn.Sequential(net, Classifier())
,而不是调用nn.ModuleList()
。
【讨论】:
以上是关于加载和冻结预训练模型以与新网络结合的主要内容,如果未能解决你的问题,请参考以下文章