PyTorch两种不同分类层的设计方法
Posted 算法与编程之美
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了PyTorch两种不同分类层的设计方法相关的知识,希望对你有一定的参考价值。
问题
涉及到图像分类的网络的最后一层分类层,有两种实现方法,如下所示,你更偏向于哪种方法呢?
方法
方法1
import torch
from torch import nn
'''
测试池化和卷积组合的分类层
'''
class MyNet(nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv = nn.Conv2d(3, 32, 3, padding=1)
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(32, 2)
def forward(self, x):
x = self.conv(x)
x = self.avg_pool(x)
x = x.view(x.size(0), -1) # 展开所有元素
out = self.classifier(x)
return out
if __name__ == '__main__':
from torchsummary import summary
device = 'cuda' if torch.cuda.is_available() else 'cpu'
x = torch.rand(size=(1, 3, 7, 7)).to(device)
net = MyNet().to(device)
summary(net, (3, 7, 7))
方法2
import torch
from torch import nn
'''
测试池化和卷积组合的分类层
'''
class MyNet(nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv = nn.Conv2d(3, 32, 3, padding=1)
self.classifier = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Flatten(),
nn.Linear(32, 2)
)
def forward(self, x):
x = self.conv(x)
out = self.classifier(x)
return out
if __name__ == '__main__':
from torchsummary import summary
device = 'cuda' if torch.cuda.is_available() else 'cpu'
x = torch.rand(size=(1, 3, 7, 7)).to(device)
net = MyNet().to(device)
out = net(x)
summary(net, (3, 7, 7))
结语
从扩展性、可读性的角度来说,更偏向于方法2的设计。
以上是关于PyTorch两种不同分类层的设计方法的主要内容,如果未能解决你的问题,请参考以下文章