PyTorch CNN not learning

Posted: 2021-04-15 23:37:17

I am currently trying to train a PyTorch CNN to classify demented and non-demented individuals based on MRI scans. During training, however, the model's loss stays constant and its accuracy remains at 0.333 while trying to distinguish 3 classes. I have tried many of the suggestions that respondents have offered for similar problems, but none of them have worked for my particular task. These include changing the number of convolutional units in the model, trying different loss functions, training the model on the original dataset before scaling up to a larger set of augmented images, and changing parameters such as the learning rate and batch size. I have attached my code and example input images below.
Image examples
Healthy Brain
Mild Cognitive Impairment Brain
Alzheimer's Brain
Preprocessing code
torch.cuda.set_device(0)

g = True
if g == True:
    for f in final_MRI_data:
        path = os.path.join(final_MRI_dir, f)
        matrix = nib.load(path)
        matrix = matrix.get_fdata()
        slice_ = matrix[90, :, :]  # take a single 2D slice from each volume
        img = Image.fromarray(slice_)
        img = img.crop((left, top, right, bottom))
        img = ImageOps.grayscale(img)
        data_matrices.append(img)

# Data augmentation: 30 transformed copies of each image
postda_data = []
for image in data_matrices:
    for i in range(30):
        transformed_img = transforms(image)
        transformed_img = np.asarray(transformed_img)
        postda_data.append(transformed_img)

# Repeat each label 30 times so labels stay aligned with the augmented images
final_MRI_labels = list(itertools.chain.from_iterable(
    itertools.repeat(x, 30) for x in final_MRI_labels))

X = torch.Tensor(np.asarray(postda_data)).view(-1, 145, 200)
print(X.size())

y = torch.Tensor(final_MRI_labels)  # target labels for the cross-entropy loss function

z = []
for val in final_MRI_labels:
    z.append(np.eye(3)[val])
z = torch.Tensor(np.asarray(z))  # one-hot target matrices for the model testing function
Network class
class Hl_Model(nn.Module):
    torch.cuda.set_device(0)

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, stride=2)
        self.conv2 = nn.Conv2d(32, 64, 3, stride=2)
        self.conv3 = nn.Conv2d(64, 128, 3, stride=2)
        self.conv4 = nn.Conv2d(128, 256, 3, stride=2)

        # Pass a dummy input through the conv stack once to infer the flattened size
        x = torch.randn(145, 200).view(-1, 1, 145, 200)
        self._to_linear = None
        self.convs(x)

        self.fc1 = nn.Linear(self._to_linear, 128, bias=True)
        self.fc2 = nn.Linear(128, 3)

    def convs(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.max_pool2d(F.relu(self.conv4(x)), (2, 2), stride=2)
        if self._to_linear is None:
            self._to_linear = x[0].shape[0] * x[0].shape[1] * x[0].shape[2]
        return x

    def forward(self, x):
        x = self.convs(x)
        x = x.view(-1, self._to_linear)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.softmax(x, dim=1)
Training function
def train(net, train_fold_x, train_fold_y):
    optimizer = optim.Adam(net.parameters(), lr=0.05)
    BATCH_SIZE = 5
    EPOCHS = 50
    for epoch in range(EPOCHS):
        for i in tqdm(range(0, len(train_fold_x), BATCH_SIZE)):
            batch_x = train_fold_x[i:i+BATCH_SIZE].view(-1, 1, 145, 200)
            batch_y = train_fold_y[i:i+BATCH_SIZE]
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            optimizer.zero_grad()
            outputs = net(batch_x)
            batch_y = batch_y.long()
            loss = loss_func(outputs, batch_y)  # loss_func = nn.CrossEntropyLoss(), per the comments below
            loss.backward()
            optimizer.step()
        print(f"Epoch: {epoch} Loss: {loss}")
Test function
def test(net, test_fold_x, test_fold_y):
    test_fold_x.to(device)
    test_fold_y.to(device)
    correct = 0
    total = 0
    with torch.no_grad():
        for i in tqdm(range(len(test_fold_x))):
            real_class = torch.argmax(test_fold_y[i]).to(device)
            net_out = net(test_fold_x[i].view(-1, 1, 145, 200).to(device))
            pred_class = torch.argmax(net_out)
            if pred_class == real_class:
                correct += 1
            total += 1
Cross-validation loop
for i in range(6):
    result = next(skf.split(X, y))
    X_train = X[result[0]]
    X_test = X[result[1]]
    y_train = y[result[0]]
    y_test = z[result[1]]
    train(hl_model, X_train, y_train)
    test(hl_model, X_test, y_test)
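As an aside (separate from the accepted answer below): each call to skf.split(X, y) creates a fresh generator, so next(skf.split(X, y)) returns the first fold on every iteration and all six rounds reuse the same split. A sketch of the usual pattern, assuming skf is a scikit-learn StratifiedKFold:

for train_idx, test_idx in skf.split(X, y):  # visits each fold exactly once
    X_train, X_test = X[train_idx], X[test_idx]
    y_train = y[train_idx]   # class indices for training
    y_test = z[test_idx]     # one-hot targets for the test function
    train(hl_model, X_train, y_train)
    test(hl_model, X_test, y_test)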
Output during training:
100%|##########| 188/188 [00:04<00:00, 45.32it/s]
Epoch: 0 Loss: 1.5514447689056396
100%|##########| 188/188 [00:02<00:00, 80.35it/s]
Epoch: 1 Loss: 1.5514447689056396
100%|##########| 188/188 [00:02<00:00, 80.66it/s]
Epoch: 2 Loss: 1.5514447689056396
This output repeats all the way to "Epoch: 49 Loss: 1.5514447689056396".
Thanks in advance for any suggestions.
Comments:

Could you provide the output during training? Also, I would like to see the definition of the loss function you are using in this case. – TQCH

@TQCH Loss function: loss_func = nn.CrossEntropyLoss(); output: Epoch: 1 Loss: 1.5514447689056396, Epoch: 2 Loss: 1.5514447689056396 ... Epoch: 49 Loss: 1.5514447689056396

Answer 1:

The problem appears to be the softmax activation in the last step of your model's forward pass, combined with the fact that your loss function, loss_func = nn.CrossEntropyLoss(), actually expects raw logits instead. See the official documentation:
class torch.nn.CrossEntropyLoss(weight: Optional[torch.Tensor] = None, size_average=None, ignore_index: int = -100, reduce=None, reduction: str = 'mean')
and
This criterion combines nn.LogSoftmax() and nn.NLLLoss() in one single class. The input is expected to contain raw, unnormalized scores for each class.
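A minimal sketch of the fix this implies (assuming the rest of Hl_Model is unchanged): have forward return the raw logits and let nn.CrossEntropyLoss apply the log-softmax internally. Note that torch.argmax picks the same class whether or not a softmax is applied, so the test function above still works; an explicit softmax is only needed if actual probabilities are wanted at inference time.

def forward(self, x):
    x = self.convs(x)
    x = x.view(-1, self._to_linear)
    x = F.relu(self.fc1(x))
    return self.fc2(x)  # raw logits; CrossEntropyLoss applies log-softmax internally

# Inference: argmax over logits equals argmax over softmax probabilities
net_out = net(batch_x)                     # logits, shape (N, 3)
pred_class = torch.argmax(net_out, dim=1)  # predicted class indices
probs = F.softmax(net_out, dim=1)          # explicit softmax only if probabilities are needed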
Comments on the answer:

Thank you so much. This has fixed the problem and my model is now learning.