使用 torch.save 和 torch.load 继续培训 - 关键错误消息

Posted

技术标签:

【中文标题】使用 torch.save 和 torch.load 继续培训 - 关键错误消息【英文标题】:Continue training with torch.save and torch.load - key error messages 【发布时间】:2021-11-29 03:57:58 【问题描述】:

我是 Torch 的新手,我使用的是蒙版 cnn 模型的代码模板。为了在培训中断时做好准备,我在代码中使用了 torch.save 和 torch.load,但我认为我不能单独使用它来继续培训课程?我开始训练:

model = train_mask_net(64)

这调用了函数 train_mask_net,其中我在 epoch 循环中包含了 torch.save。我想加载其中一个保存的模型并在循环前使用 torch.load 继续训练,但我收到了优化器、损失和 epoch 调用的“关键错误”消息。我应该像在一些教程中看到的那样创建一个特定的检查点功能,还是有可能我可以使用 torch.saved 命令保存的文件继续训练?

def train_mask_net(num_epochs=1):
    data = MaskDataset(list(data_mask.keys()))
    data_loader = torch.utils.data.DataLoader(data, batch_size=8, shuffle=True, num_workers=4)

    model = XceptionHourglass(max_clz+2)
    model.cuda()
    dp = torch.nn.DataParallel(model)
    loss = nn.CrossEntropyLoss()

    params = [p for p in dp.parameters() if p.requires_grad]
    optimizer = torch.optim.RMSprop(params, lr=2.5e-4,  momentum=0.9)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                   step_size=6,
                                                   gamma=0.9)
    
    checkpoint = torch.load('imaterialist2020-pretrain-models/maskmodel_160.model_ep17')
    #print(checkpoint)
    model.load_state_dict(checkpoint)
    #optimizer.load_state_dict(checkpoint)
    #epoch = checkpoint['epoch']
    #loss = checkpoint['loss']
    
    for epoch in range(num_epochs):
        print(epoch)
        total_loss = []
        prog = tqdm(data_loader, total=len(data_loader))
        for i, (imag, mask) in enumerate(prog):
            X = imag.cuda()
            y = mask.cuda()
            xx = dp(X)
            # to 1D-array
            y = y.reshape((y.size(0),-1))  # batch, flatten-img
            y = y.reshape((y.size(0) * y.size(1),))  # flatten-all
            xx = xx.reshape((xx.size(0), xx.size(1), -1))  # batch, channel, flatten-img
            xx = torch.transpose(xx, 2, 1)  # batch, flatten-img, channel
            xx = xx.reshape((xx.size(0) * xx.size(1),-1))  # flatten-all, channel

            losses = loss(xx, y)

            prog.set_description("loss:%05f"%losses)
            optimizer.zero_grad()
            losses.backward()
            optimizer.step()

            total_loss.append(losses.detach().cpu().numpy())
            torch.save(model.state_dict(), MODEL_FILE_DIR+"maskmodel_%d.model"%attr_image_size[0]+'_ep'+str(epoch)+'_tsave')

        prog, X, xx, y, losses = None, None, None, None, None,
        torch.cuda.empty_cache()
        gc.collect()
    return model

我认为没有必要,但 xceptionhour 类看起来像这样:

class XceptionHourglass(nn.Module):
    def __init__(self, num_classes):
        super(XceptionHourglass, self).__init__()
        self.num_classes = num_classes

        self.conv1 = nn.Conv2d(3, 128, 3, 2, 1, bias=True)
        self.bn1 = nn.BatchNorm2d(128)
        self.mish = Mish()

        self.conv2 = nn.Conv2d(128, 256, 3, 1, 1, bias=True)
        self.bn2 = nn.BatchNorm2d(256)

        self.block1 = HourglassNet(4, 256)
        self.bn3 = nn.BatchNorm2d(256)
        self.block2 = HourglassNet(4, 256)
...

【问题讨论】:

还有一个问题:我的检查点在for循环前面的位置是否正确才能继续训练? 【参考方案1】:

torch.save(model.state_dict(), PATH) 只保存模型权重。

要同时保存优化器、损失、纪元等,将其更改为:

torch.save('model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'loss': loss,
            'epoch': epoch,
            # ...
            , PATH)

要加载它们:

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

更多信息here。

【讨论】:

thx,因为我没有包括优化器等。根本没有机会使用我原来的 torch.save 命令保存的对象继续训练?我问是因为培训需要很长时间,否则我想我需要重新开始 我知道我的代码在哪个epoch中断了,但是我不知道那个时候损失函数的值。也许有可能。

以上是关于使用 torch.save 和 torch.load 继续培训 - 关键错误消息的主要内容,如果未能解决你的问题,请参考以下文章

将火炬模型(torch.save)转换为可以用基本 Python 处理的矩阵公式

Pytorch torch.save() 保存特征向量

Pytorch torch.save() 保存特征向量

PyTorch保存和加载模型

每天讲解一点PyTorch 15model.load_state_dict torch.load torch.save

每天讲解一点PyTorch 15model.load_state_dict torch.load torch.save