Pytorch CNN not learning

Posted: 2021-04-15 23:37:17

Problem description:

I am currently trying to train a PyTorch CNN to classify demented and non-demented individuals based on MRI scans. During training, however, the model's loss never changes, and accuracy stays at 0.333 when it tries to distinguish the 3 classes. I have tried many of the suggestions people have given for similar problems, but none of them worked for my specific task. These included changing the number of convolutional units in the model, trying different loss functions, training on the original dataset before scaling up to the larger augmented image set, and changing hyperparameters such as the learning rate and batch size. I have attached my code and example input images below.

Image examples

Healthy Brain

Mild Cognitive Impairment Brain

Alzheimer's Brain

Preprocessing code

torch.cuda.set_device(0)
g = True
if g == True:
    for f in final_MRI_data:
        path = os.path.join(final_MRI_dir, f)
        matrix = nib.load(path).get_fdata()          # load the NIfTI volume as a numpy array
        slice_ = matrix[90, :, :]                    # take a single 2-D slice from the volume
        img = Image.fromarray(slice_)
        img = img.crop((left, top, right, bottom))
        img = ImageOps.grayscale(img)
        data_matrices.append(img)

postda_data = []
for image in data_matrices:
    for i in range(30):                              # 30 augmented copies per slice
        transformed_img = transforms(image)
        transformed_img = np.asarray(transformed_img)
        postda_data.append(transformed_img)

# repeat each label 30 times so labels stay aligned with the augmented images
final_MRI_labels = list(itertools.chain.from_iterable(
    itertools.repeat(x, 30) for x in final_MRI_labels))

X = torch.Tensor(np.asarray(postda_data)).view(-1, 145, 200)
print(X.size())

y = torch.Tensor(final_MRI_labels)  # target labels for the cross-entropy loss function

z = []
for val in final_MRI_labels:
    z.append(np.eye(3)[val])
z = torch.Tensor(np.asarray(z))  # one-hot encoded targets used by the test function
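For completeness, the imports and several names used above (nib, Image, ImageOps, transforms, data_matrices, the crop box, etc.) are not shown in the question. A minimal sketch of what that setup might look like, where the directory path, crop coordinates, and augmentation choices are all assumptions:

# assumed setup for the preprocessing snippet above; names and values are guesses
import os, itertools
import numpy as np
import nibabel as nib
import torch
from PIL import Image, ImageOps
from torchvision import transforms as T

final_MRI_dir = "path/to/scans"              # hypothetical data directory
final_MRI_data = os.listdir(final_MRI_dir)   # NIfTI file names
final_MRI_labels = []                        # one integer label (0/1/2) per scan, filled elsewhere
left, top, right, bottom = 20, 20, 220, 165  # crop box giving a 200x145 image; actual values unknown
data_matrices = []

# the `transforms` callable used for augmentation is not shown in the question;
# a plausible torchvision pipeline might look like this
transforms = T.Compose([
    T.RandomRotation(10),
    T.RandomAffine(degrees=0, translate=(0.05, 0.05)),
])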

Network class

class Hl_Model(nn.Module):

    torch.cuda.set_device(0)

    def __init__(self):
        super().__init__()

        self.conv1 = nn.Conv2d(1, 32, 3, stride=2)
        self.conv2 = nn.Conv2d(32, 64, 3, stride=2)
        self.conv3 = nn.Conv2d(64, 128, 3, stride=2)
        self.conv4 = nn.Conv2d(128, 256, 3, stride=2)

        x = torch.randn(145, 200).view(-1, 1, 145, 200)
        self._to_linear = None
        self.convs(x)   # dummy forward pass to infer the flattened feature size
    
        self.fc1 = nn.Linear(self._to_linear, 128, bias=True)
        self.fc2 = nn.Linear(128, 3)

    def convs(self, x):

        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.max_pool2d(F.relu(self.conv4(x)), (2, 2), stride=2)

        if self._to_linear is None:
            self._to_linear = x[0].shape[0]*x[0].shape[1]*x[0].shape[2]
        return x


    def forward(self, x):
        x = self.convs(x)
        x = x.view(-1, self._to_linear)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.softmax(x, dim=1)
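As a quick sanity check (not part of the original question), the dummy pass in __init__ means the model can be instantiated and probed directly; something along these lines should print the inferred flattened feature size and an output of shape (1, 3):

# hypothetical sanity check: instantiate and probe the model on CPU
net = Hl_Model()
dummy = torch.randn(1, 1, 145, 200)
print(net._to_linear)      # flattened feature size inferred by the dummy pass in __init__
print(net(dummy).shape)    # expected: torch.Size([1, 3])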

Training function

def train(net, train_fold_x, train_fold_y):

    optimizer = optim.Adam(net.parameters(), lr=0.05)
    BATCH_SIZE = 5
    EPOCHS = 50
    for epoch in range(EPOCHS):
        for i in tqdm(range(0, len(train_fold_x), BATCH_SIZE)):

            batch_x = train_fold_x[i:i+BATCH_SIZE].view(-1, 1, 145, 200)
            batch_y = train_fold_y[i:i+BATCH_SIZE]
        
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
        
            optimizer.zero_grad()
            outputs = net(batch_x)

            batch_y = batch_y.long()
            loss = loss_func(outputs, batch_y)

            loss.backward()
            optimizer.step()
        
        print(f"Epoch: epoch Loss: loss")

Test function

def test(net, test_fold_x, test_fold_y):

    # .to() is not in-place for tensors, so the result has to be reassigned
    test_fold_x = test_fold_x.to(device)
    test_fold_y = test_fold_y.to(device)

    correct = 0
    total = 0
    with torch.no_grad():
        for i in tqdm(range(len(test_fold_x))):
            real_class = torch.argmax(test_fold_y[i]).to(device)
            net_out = net(test_fold_x[i].view(-1, 1, 145, 200).to(device))
            pred_class = torch.argmax(net_out)

            if pred_class == real_class:
                correct += 1
            total += 1
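The function above only tallies correct and total; the 0.333 accuracy mentioned in the question (chance level for three classes) is presumably derived from those counters, e.g.:

# presumably computed from the counters above at the end of test():
accuracy = correct / total     # stuck at ~0.333, i.e. chance level for 3 classes
print(f"Fold accuracy: {accuracy:.3f}")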

Cross-validation loop

for i in range(6):
    result = next(skf.split(X, y))
    X_train = X[result[0]]
    X_test = X[result[1]]
    y_train = y[result[0]]
    y_test = z[result[1]]
    train(hl_model, X_train, y_train)
    test(hl_model, X_test, y_test)
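skf is not defined in the question; the sketch below assumes it is sklearn's StratifiedKFold. Note also that calling next(skf.split(X, y)) inside the loop, as above, builds a fresh generator on every iteration and therefore always returns the first fold; iterating over a single split() generator avoids that:

# assumed fold setup; StratifiedKFold is a guess based on the name `skf`
from sklearn.model_selection import StratifiedKFold
skf = StratifiedKFold(n_splits=6, shuffle=True)

# iterating one split() generator visits every fold instead of repeating the first
for train_idx, test_idx in skf.split(X, y):
    train(hl_model, X[train_idx], y[train_idx])
    test(hl_model, X[test_idx], z[test_idx])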

Output during training:

100%|##########| 188/188 [00:04<00:00, 45.32it/s]
Epoch: 0 Loss: 1.5514447689056396

100%|##########| 188/188 [00:02<00:00, 80.35it/s]
Epoch: 1 Loss: 1.5514447689056396

100%|##########| 188/188 [00:02<00:00, 80.66it/s]
Epoch: 2 Loss: 1.5514447689056396

This output repeats unchanged all the way to "Epoch: 49 Loss: 1.5514447689056396".

Thanks in advance for any suggestions.

Comments:

Could you provide the output during training? Also, I would like to see the definition of the loss function you are using in this case.
@TQCH Loss function: loss_func = nn.CrossEntropyLoss(); output: Epoch: 1 Loss: 1.5514447689056396, Epoch: 2 Loss: 1.5514447689056396 ... Epoch: 49 Loss: 1.5514447689056396

Answer 1:

The problem appears to be the softmax activation in the last step of your model's forward pass, combined with your loss function loss_func = nn.CrossEntropyLoss(), which actually expects raw logits instead. See the official documentation:

class torch.nn.CrossEntropyLoss(weight: Optional[torch.Tensor] = None, size_average=None, ignore_index: int = -100, reduce=None, reduction: str = 'mean')

This criterion combines nn.LogSoftmax() and nn.NLLLoss() in one single class. The input is expected to contain raw, unnormalized scores for each class.
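A minimal sketch of the fix, assuming the rest of the model stays unchanged: have forward return the raw logits and let nn.CrossEntropyLoss apply log-softmax internally; apply softmax explicitly only when probabilities are needed at inference time.

    def forward(self, x):
        x = self.convs(x)
        x = x.view(-1, self._to_linear)
        x = F.relu(self.fc1(x))
        return self.fc2(x)   # raw logits go straight into nn.CrossEntropyLoss

# if class probabilities are needed for reporting, apply softmax outside the loss:
# probs = F.softmax(net(batch_x), dim=1)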

Discussion:

Thank you very much. This fixed the problem and my model is now learning.
