RuntimeError: 给定组=1,大小为 [32, 3, 16, 16, 16] 的权重,预期输入 [100, 16, 16, 16, 3] 有 3 个通道,但有 16 个通道

Posted

技术标签:

【中文标题】RuntimeError: 给定组=1,大小为 [32, 3, 16, 16, 16] 的权重,预期输入 [100, 16, 16, 16, 3] 有 3 个通道,但有 16 个通道【英文标题】:RuntimeError: Given groups=1, weight of size [32, 3, 16, 16, 16], expected input[100, 16, 16, 16, 3] to have 3 channels, but got 16 channels instead 【发布时间】:2020-10-06 13:16:25 【问题描述】:

这是我认为问题所在的代码部分。

def __init__(self):
        super(Lightning_CNNModel, self).__init__()

        self.conv_layer1 = self._conv_layer_set(3, 32)
        self.conv_layer2 = self._conv_layer_set(32, 64)
        self.fc1 = nn.Linear(2**3*64, 128)
        self.fc2 = nn.Linear(128, 10)   # num_classes = 10
        self.relu = nn.LeakyReLU()
        self.batch=nn.BatchNorm1d(128)
        self.drop=nn.Dropout(p=0.15)

    def _conv_layer_set(self, in_c, out_c):
        conv_layer = nn.Sequential(
            nn.Conv3d(in_c, out_c, kernel_size=(3, 3, 3), padding=0),
            nn.LeakyReLU(),
            nn.MaxPool3d((2, 2, 2)),
        )

        return conv_layer



    def forward(self, x):
        out = self.conv_layer1(x)
        out = self.conv_layer2(out)
        out = out.view(out.size(0), -1)
        out = self.fc1(out)
        out = self.relu(out)
        out = self.batch(out)
        out = self.drop(out)
        out = self.fc2(out)

        return out

这是我正在处理的代码

【问题讨论】:

【参考方案1】:

nn.Conv3d 期望输入的大小为 [batch_size, channels, depth, height, width]。第一个卷积需要 3 个通道,但您的输入大小为 [100, 16, 16, 16, 3],这将是 16 个通道。

假设你的数据是[batch_size, depth, height, width, channels],你需要交换尺寸,这可以用torch.Tensor.permute来完成:

# From: [batch_size, depth, height, width, channels]
# To: [batch_size, channels, depth, height, width]
input = input.permute(0, 4, 1, 2, 3)

【讨论】:

非常感谢您的帮助。我明白了。

以上是关于RuntimeError: 给定组=1,大小为 [32, 3, 16, 16, 16] 的权重,预期输入 [100, 16, 16, 16, 3] 有 3 个通道,但有 16 个通道的主要内容,如果未能解决你的问题,请参考以下文章

RuntimeError: Expected hidden[0] size (1, 1, 512), got (1, 128, 512) for LSTM pytorch

c_cpp 以给定大小的组反转链接列表设置1

c_cpp GFG以给定大小的组反转数组

c_cpp 以给定大小的组反转链接列表设置2

RuntimeError:字典迭代过程中改变大小 - 如何解决?

python 报错RuntimeError: dictionary changed size during iteration