无法在 pytorch 中替换 Densenet121 上的分类器
Posted
技术标签:
【中文标题】无法在 pytorch 中替换 Densenet121 上的分类器【英文标题】:Can't replace classifier on Densenet121 in pytorch 【发布时间】:2020-01-08 10:46:03 【问题描述】:我正在尝试使用这个 github DenseNet121 模型 (https://github.com/gaetandi/cheXpert.git) 进行一些迁移学习。我遇到了将分类层从 14 个输出调整为 2 个输出的问题。
github代码的相关部分是:
class DenseNet121(nn.Module):
"""Model modified.
The architecture of our model is the same as standard DenseNet121
except the classifier layer which has an additional sigmoid function.
"""
def __init__(self, out_size):
super(DenseNet121, self).__init__()
self.densenet121 = torchvision.models.densenet121(pretrained=True)
num_ftrs = self.densenet121.classifier.in_features
self.densenet121.classifier = nn.Sequential(
nn.Linear(num_ftrs, out_size),
nn.Sigmoid()
)
def forward(self, x):
x = self.densenet121(x)
return x
我加载并初始化:
# initialize and load the model
model = DenseNet121(nnClassCount).cuda()
model = torch.nn.DataParallel(model).cuda()
modeldict = torch.load("model_ones_3epoch_densenet.tar")
model.load_state_dict(modeldict['state_dict'])
DenseNet 似乎没有将层拆分为子层,因此 model = nn.Sequential(*list(modelRes.children())[:-1])
将无法工作。
model.classifier = nn.Linear(1024, 2)
似乎可以在默认的 DenseNets 上工作,但使用修改后的分类器(附加 sigmoid 函数),它最终只是添加了一个额外的分类器层,而不替换原来的。
我试过了
model.classifier = nn.Sequential(
nn.Linear(1024, dset_classes_number),
nn.Sigmoid()
)
但是我有相同的添加而不是替换分类器问题:
...
)
(classifier): Sequential(
(0): Linear(in_features=1024, out_features=14, bias=True)
(1): Sigmoid()
)
)
)
(classifier): Sequential(
(0): Linear(in_features=1024, out_features=2, bias=True)
(1): Sigmoid()
)
)
【问题讨论】:
你为什么说self.densenet121.classifier = nn.Sequential(...)
不起作用?为何如此?你得到什么错误?
没有错误,只是一个重复的分类器。我更新了问题以向您展示我在 print(model) 后得到的结果
您只需更换分类器一次。内幕init你做对了,不需要再做
我想在更改分类器层之前从预训练模型中加载 state_dict。不用我用14个输出初始化,加载权重,然后改分类器吗?
这正是你在 init 中所做的
【参考方案1】:
如果您想替换 densenet121
中的 classifier
,它是您的 model
的成员,您需要分配
model.densenet121.classifier = nn.Sequential(...)
【讨论】:
这给了我AttributeError: 'DataParallel' object has no attribute 'densenet121'
稍作调整使其工作:model.module.densenet121.classifier = nn.Sequential(...)
以上是关于无法在 pytorch 中替换 Densenet121 上的分类器的主要内容,如果未能解决你的问题,请参考以下文章
如何用其他 pytorch 函数替换 torch.sparse?
Pytorch 复制替换网络中的部分参数,网络参数的定向赋值