PyTorch两种不同分类层的设计方法

Posted 算法与编程之美

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了PyTorch两种不同分类层的设计方法相关的知识,希望对你有一定的参考价值。

问题

涉及到图像分类的网络的最后一层分类层,有两种实现方法,如下所示,你更偏向于哪种方法呢?

方法

方法1

import torch
from torch import nn


'''
测试池化和卷积组合的分类层
'''
class MyNet(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        
        self.conv = nn.Conv2d(3, 32, 3, padding=1)
        
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(32, 2)
        
    
    def forward(self, x):
        x = self.conv(x)
        x = self.avg_pool(x)
        x = x.view(x.size(0), -1) # 展开所有元素
        out = self.classifier(x)
        
        return out
    
if __name__ == '__main__':

    from torchsummary import summary
    
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
        
    x = torch.rand(size=(1, 3, 7, 7)).to(device)
    net = MyNet().to(device)
    
    summary(net, (3, 7, 7))
    


方法2

import torch
from torch import nn


'''
测试池化和卷积组合的分类层
'''
class MyNet(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        
        self.conv = nn.Conv2d(3, 32, 3, padding=1)
        
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(32, 2)
        )
        
    
    def forward(self, x):
        x = self.conv(x)
        out = self.classifier(x)
        
        return out
    
if __name__ == '__main__':

    from torchsummary import summary
    
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
        
    x = torch.rand(size=(1, 3, 7, 7)).to(device)
    net = MyNet().to(device)
    out = net(x)
    
    summary(net, (3, 7, 7))
    


结语

从扩展性、可读性的角度来说,更偏向于方法2的设计。

以上是关于PyTorch两种不同分类层的设计方法的主要内容,如果未能解决你的问题,请参考以下文章

Pytorch:将 VGG 模型转换为顺序模型,但得到不同的输出

如何在 Pytorch 中应用分层学习率?

pytorch 图像分类

最后一层的张量输出在 PyTorch 中的形状错误

动手学习pytorch——多层感知机

基于Pytorch实现的声音分类