PyTorch实现InceptionV1模块
Posted 算法与编程之美
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了PyTorch实现InceptionV1模块相关的知识,希望对你有一定的参考价值。
问题
方法
'''
InceptionV1模块
'''
class Inception(nn.Module):
def __init__(self, in_channels, ch1x1,
ch3x3red, ch3x3,
ch5x5red, ch5x5,
pool_proj):
super(Inception, self).__init__()
self.branch1 = BasicConv2d(in_channels, ch1x1, kernel_size=1)
self.branch2 = nn.Sequential(
BasicConv2d(in_channels, ch3x3red, kernel_size=1),
BasicConv2d(ch3x3red, ch3x3, kernel_size=3, padding=1)
)
self.branch3 = nn.Sequential(
BasicConv2d(in_channels, ch5x5red, kernel_size=1),
BasicConv2d(ch5x5red, ch5x5, kernel_size=5, padding=2)
)
self.branch4 = nn.Sequential(
nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
BasicConv2d(in_channels, pool_proj, kernel_size=1)
)
def forward(self, x):
b1 = self.branch1(x)
b2 = self.branch2(x)
b3 = self.branch3(x)
b4 = self.branch4(x)
return torch.cat([b1, b2, b3, b4], dim=1)
结语
以上是关于PyTorch实现InceptionV1模块的主要内容,如果未能解决你的问题,请参考以下文章