无法在 pytorch 中替换 Densenet121 上的分类器

Posted

技术标签:

【中文标题】无法在 pytorch 中替换 Densenet121 上的分类器【英文标题】:Can't replace classifier on Densenet121 in pytorch 【发布时间】:2020-01-08 10:46:03 【问题描述】:

我正在尝试使用这个 github DenseNet121 模型 (https://github.com/gaetandi/cheXpert.git) 进行一些迁移学习。我遇到了将分类层从 14 个输出调整为 2 个输出的问题。

github代码的相关部分是:

class DenseNet121(nn.Module):
    """Model modified.
    The architecture of our model is the same as standard DenseNet121
    except the classifier layer which has an additional sigmoid function.
    """
    def __init__(self, out_size):
        super(DenseNet121, self).__init__()
        self.densenet121 = torchvision.models.densenet121(pretrained=True)
        num_ftrs = self.densenet121.classifier.in_features
        self.densenet121.classifier = nn.Sequential(
            nn.Linear(num_ftrs, out_size),
            nn.Sigmoid()
        )
def forward(self, x):
    x = self.densenet121(x)
    return x

我加载并初始化:

# initialize and load the model
model = DenseNet121(nnClassCount).cuda()
model = torch.nn.DataParallel(model).cuda()
modeldict = torch.load("model_ones_3epoch_densenet.tar")
model.load_state_dict(modeldict['state_dict'])

DenseNet 似乎没有将层拆分为子层,因此 model = nn.Sequential(*list(modelRes.children())[:-1]) 将无法工作。

model.classifier = nn.Linear(1024, 2) 似乎可以在默认的 DenseNets 上工作,但使用修改后的分类器(附加 sigmoid 函数),它最终只是添加了一个额外的分类器层,而不替换原来的。

我试过了

model.classifier = nn.Sequential(
    nn.Linear(1024, dset_classes_number), 
    nn.Sigmoid()
)

但是我有相同的添加而不是替换分类器问题:

...
      )
      (classifier): Sequential(
        (0): Linear(in_features=1024, out_features=14, bias=True)
        (1): Sigmoid()
      )
    )
  )
  (classifier): Sequential(
    (0): Linear(in_features=1024, out_features=2, bias=True)
    (1): Sigmoid()
  )
)

【问题讨论】:

你为什么说self.densenet121.classifier = nn.Sequential(...)不起作用?为何如此?你得到什么错误? 没有错误,只是一个重复的分类器。我更新了问题以向您展示我在 print(model) 后得到的结果 您只需更换分类器一次。内幕init你做对了,不需要再做 我想在更改分类器层之前从预训练模型中加载 state_dict。不用我用14个输出初始化,加载权重,然后改分类器吗? 这正是你在 init 中所做的 【参考方案1】:

如果您想替换 densenet121 中的 classifier ,它是您的 model 的成员,您需要分配

model.densenet121.classifier = nn.Sequential(...)

【讨论】:

这给了我AttributeError: 'DataParallel' object has no attribute 'densenet121' 稍作调整使其工作:model.module.densenet121.classifier = nn.Sequential(...)

以上是关于无法在 pytorch 中替换 Densenet121 上的分类器的主要内容,如果未能解决你的问题,请参考以下文章

如何用其他 pytorch 函数替换 torch.sparse?

经验分享使用 Rp 类对 pytorch 算子作替换操作

Pytorch 复制替换网络中的部分参数,网络参数的定向赋值

我无法在 jupyter 和 Spyder 中安装 pytorch?

无法在Pytorch中使用GPU

无法在pytorch python中使用多目标损失函数