如何使用 PyTorch 模型进行预测?

Posted

技术标签:

【中文标题】如何使用 PyTorch 模型进行预测?【英文标题】:How do I predict using a PyTorch model? 【发布时间】:2021-07-01 07:03:14 【问题描述】:

我创建了一个 pyTorch 模型来对图像进行分类。 我通过 state_dict 和整个模型保存了一次:

torch.save(model.state_dict(), "model1_statedict")
torch.save(model, "model1_complete")

如何使用这些模型? 我想用一些图片来检查它们是否良好。

我正在加载模型:

model = torch.load(path_model)
model.eval()

这没问题,但我不知道如何使用它来预测新图片。

【问题讨论】:

我编辑了你的问题,因为很遗憾这里不允许请求资源,例如教程 好的对不起,我不知道,谢谢编辑 【参考方案1】:
def predict(self, test_images):
    self.eval()
    # model is self(VGG class's object)
    
    count = test_images.shape[0]
    result_np = []
        
    for idx in range(0, count):
        # print(idx)
        img = test_images[idx, :, :, :]
        img = np.expand_dims(img, axis=0)
        img = torch.Tensor(img).permute(0, 3, 1, 2).to(device)
        # print(img.shape)
        pred = self(img)
        pred_np = pred.cpu().detach().numpy()
        for elem in pred_np:
            result_np.append(elem)
    return result_np

网络是 VGG-19 并参考我的源代码。

喜欢这样的架构:

class VGG(object):
    def __init__(self):
    ...


    def train(self, train_images, valid_images):
        train_dataset = torch.utils.data.Dataset(train_images)
        valid_dataset = torch.utils.data.Dataset(valid_images)

        trainloader = torch.utils.data.DataLoader(train_dataset)
        validloader = torch.utils.data.DataLoader(valid_dataset)

        self.optimizer = Adam(...)
        self.criterion = CrossEntropyLoss(...)
    
        for epoch in range(0, epochs):
            ...
            self.evaluate(validloader, model=self, criterion=self.criterion)
    ...

    def evaluate(self, dataloader, model, criterion):
        model.eval()
        for i, sample in enumerate(dataloader):
    ...

    def predict(self, test_images):
    
    ...

if __name__ == "__main__":
    network = VGG()
    trainset, validset = get_dataset()    # abstract function for showing
    testset = get_test_dataset()
    
    network.train(trainset, validset)

    result = network.predict(testset)

【讨论】:

【参考方案2】:

pytorch 模型是一个函数。您为它提供适当定义的输入,它会返回一个输出。如果您只是想直观地检查给定特定输入图像的输出,只需调用它:

model.eval()
output = model(example_image)

【讨论】:

以上是关于如何使用 PyTorch 模型进行预测?的主要内容,如果未能解决你的问题,请参考以下文章

pytorch闪电模型的输出预测

在Python中使用LSTM和PyTorch进行时间序列预测

如何修复 Pytorch-Forecasting 模型拟合关于序列元素的 ValueError

使用 pytorch-lightning 进行简单预测的示例

基于pytorch搭建多特征LSTM时间序列预测代码详细解读(附完整代码)

Pytorch Note56 Fine-tuning 通过微调进行迁移学习