自定义 CNN 给出错误的输出形状

Posted

技术标签:

【中文标题】自定义 CNN 给出错误的输出形状【英文标题】:Custom CNN gives wrong output shape 【发布时间】:2020-12-18 03:08:05 【问题描述】:

我需要一些帮助。我正在尝试制作一个自定义 CNN,它应该接受一个通道图像并进行二进制分类。这是模型:

class custom_small_CNN(nn.Module):

    def __init__(self, input_channels=1, output_features=1):
        super(custom_small_CNN, self).__init__()

        self.input_channels = input_channels
        self.output_features = output_features

        self.conv1 = nn.Conv2d(self.input_channels, 8, kernel_size=(7, 7), stride=(2, 2), padding=(6, 6), dilation=(2, 2))
        self.conv2 = nn.Conv2d(8, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), dilation=(1, 1))
        self.pool = nn.MaxPool2d(kernel_size=(2, 2))
        self.fc1 = nn.Linear(in_features=1024, out_features=self.output_features, bias=True)
        self.dropout = nn.Dropout(p=0.5)
        self.softmax = nn.Softmax(dim=1)
        self.net_name = 'Custom_Small_CNN'

        self.net = nn.Sequential(self.conv1, self.pool, self.conv2, self.pool, self.fc1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.pool(x)
        #x = self.dropout(x)
        x = self.conv2(x)
        x = self.pool(x)
        x = x.view(-1, 1024)
        x = self.dropout(x)
        x = self.fc1(x)
        if not self.output_features == 1:
            x = self.softmax(x)
        return x

但是,当我像这样在模型中放置一个包含 4 个图像(全为零)的示例批次时:

x = torch.from_numpy(np.zeros((4, 1, 256, 256))).float()
net = custom_small_CNN(output_features=2, input_channels=1).float()
output = net(x)

输出的形状为torch.Size([16, 2]) 而不是torch.Size([4, 2]),这是我想要的,例如ResNet 作为输出提供。我错过了什么? 谢谢!

【问题讨论】:

【参考方案1】:

当你应用池化层时,它会返回 (batch_size, 2, 2, num_filters),所以当你 reshape x = x.view(-1, 1024) 时,它会导致 (batch_size * 4, num_filters) 为形状。

您应该扁平化或平均池化层的输出,而不是像那样重塑。扁平化在这里最常用。

所以,替换下面的行

x = x.view(-1, 1024)

x = nn.Flatten()(x)

会产生正确的最终输出形状

【讨论】:

以上是关于自定义 CNN 给出错误的输出形状的主要内容,如果未能解决你的问题,请参考以下文章

在自定义数据集上微调 MobileNet 时出现形状错误

在 CNN 的 keras 自定义损失函数中操作数据

Keras 自定义损失函数 dtype 错误

通过自定义 LSTM 时的形状错误

在 Keras 自定义层中连接多个形状为 (None, m) 的 LSTM 输出

在 pytorch 中为 CNN 设置自定义内核