ResNeXt

Posted by AI浩

import torch
import torch.nn as nn


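class ResNeXtBlock(nn.Module):
    """ResNeXt bottleneck: 1x1 reduce -> grouped 3x3 -> 1x1 expand, plus a shortcut."""

    def __init__(self, in_places, places, stride=1, downsampling=False, expansion=2, cardinality=32):
        super().__init__()
        self.expansion = expansion
        # Project the shortcut whenever its shape would not match the main branch,
        # even if the caller did not request it explicitly.
        self.downsampling = downsampling or stride != 1 or in_places != places * expansion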

        self.bottleneck = nn.Sequential(
            # 1x1 conv: map the input to the bottleneck width
            nn.Conv2d(in_channels=in_places, out_channels=places, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(places),
            nn.ReLU(inplace=True),
            # 3x3 grouped conv: `cardinality` parallel paths; `places` must be divisible by `cardinality`
            nn.Conv2d(in_channels=places, out_channels=places, kernel_size=3, stride=stride, padding=1, bias=False, groups=cardinality),
            nn.BatchNorm2d(places),
            nn.ReLU(inplace=True),
            # 1x1 conv: expand to places * expansion output channels
            nn.Conv2d(in_channels=places, out_channels=places * self.expansion, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(places * self.expansion),
        )

        if self.downsampling:
            # 1x1 projection so the shortcut matches the main branch in channels and spatial size
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels=in_places, out_channels=places * self.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(places * self.expansion),
            )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        residual = x
        out = self.bottleneck(x)

        if self.downsampling:
            residual = self.downsample(x)

        # Residual connection followed by the final activation
        out += residual
        out = self.relu(out)
        return out


if __name__ == '__main__':
    model = ResNeXtBlock(in_places=256, places=128)
    print(model)

    # `x` instead of `input`, which would shadow the Python built-in
    x = torch.randn(1, 256, 64, 64)
    out = model(x)
    print(out.shape)  # torch.Size([1, 256, 64, 64])
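
The grouped 3x3 convolution is the layer that the cardinality argument controls. As a rough illustration in plain Python (no PyTorch needed), the sketch below counts the weights of that layer with and without grouping. The helper name conv2d_weight_count is made up for this example and is not part of PyTorch; it assumes square kernels and no bias term, as in the block above.

```python
def conv2d_weight_count(in_channels, out_channels, kernel_size, groups=1):
    """Number of weights in a bias-free 2D convolution with a square kernel."""
    if in_channels % groups or out_channels % groups:
        raise ValueError("channel counts must be divisible by groups")
    # Each output channel only sees in_channels / groups input channels.
    return (in_channels // groups) * out_channels * kernel_size * kernel_size


places, cardinality = 128, 32  # values used in the example above
dense = conv2d_weight_count(places, places, 3)
grouped = conv2d_weight_count(places, places, 3, groups=cardinality)
print(dense, grouped, dense // grouped)  # 147456 4608 32
```

With these settings the grouped layer uses 32 times fewer weights than a dense 3x3 convolution of the same width, which is why a ResNeXt block can afford a wider bottleneck at a similar parameter budget.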
