写残差块的时候,输入和输出大小一样/不一样 如何把他们写在同一个代码里?

Posted Tina姐

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了写残差块的时候,输入和输出大小一样/不一样 如何把他们写在同一个代码里?相关的知识,希望对你有一定的参考价值。

class full_pre_resblock(nn.Module):
    def __init__(self, filters, stride):
        super(full_pre_resblock, self).__init__()

        self.relu_bn_conv = nn.Sequential(
            relu_bn(filters),
            nn.Conv2d(in_channels=filters, out_channels=filters, stride=stride, kernel_size=3, padding=1),
            relu_bn(filters),
            nn.Conv2d(in_channels=filters, out_channels=filters, stride=1, kernel_size=3, padding=1)
        )

        if stride != 1:
            self.shortcut = nn.Conv2d(in_channels=filters, out_channels=filters, stride=stride, kernel_size=1)

    def forward(self, x):
        residual = self.shortcut(x) if hasattr(self, 'shortcut') else x
        x = self.relu_bn_conv(x)
        return residual + x

这里使用 hasattr 来判断是否需要对输入改变大小。

通过这种方式,代码变得很简洁。

以上是关于写残差块的时候,输入和输出大小一样/不一样 如何把他们写在同一个代码里?的主要内容,如果未能解决你的问题,请参考以下文章

自说自话1

Gym - 101550A(Artwork 倒序+并查集)

Ebay 面试题 | 把数组分成和大小一样的集合

神经网络 残差网络

微信朋友圈字体大小不一样怎么弄得

word编辑的时候,前后两页的纸张大小要设置成不一样的,怎么办