写残差块的时候,输入和输出大小一样/不一样 如何把他们写在同一个代码里?
Posted Tina姐
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了写残差块的时候,输入和输出大小一样/不一样 如何把他们写在同一个代码里?相关的知识,希望对你有一定的参考价值。
class full_pre_resblock(nn.Module):
def __init__(self, filters, stride):
super(full_pre_resblock, self).__init__()
self.relu_bn_conv = nn.Sequential(
relu_bn(filters),
nn.Conv2d(in_channels=filters, out_channels=filters, stride=stride, kernel_size=3, padding=1),
relu_bn(filters),
nn.Conv2d(in_channels=filters, out_channels=filters, stride=1, kernel_size=3, padding=1)
)
if stride != 1:
self.shortcut = nn.Conv2d(in_channels=filters, out_channels=filters, stride=stride, kernel_size=1)
def forward(self, x):
residual = self.shortcut(x) if hasattr(self, 'shortcut') else x
x = self.relu_bn_conv(x)
return residual + x
这里使用 hasattr
来判断是否需要对输入改变大小。
通过这种方式,代码变得很简洁。
以上是关于写残差块的时候,输入和输出大小一样/不一样 如何把他们写在同一个代码里?的主要内容,如果未能解决你的问题,请参考以下文章