如何加载和使用 PyTorch (.pth.tar) 模型

Posted

技术标签:

【中文标题】如何加载和使用 PyTorch (.pth.tar) 模型【英文标题】:How can I load and use a PyTorch (.pth.tar) model 【发布时间】:2019-01-22 05:55:05 【问题描述】:

我对 Torch 不是很熟悉,我主要使用 Tensorflow。但是,我需要使用在 Torch 中重新训练的重新训练的初始模型。由于为我的特定应用重新训练初始模型需要大量计算资源,我想使用已经重新训练的模型。

此模型保存为.pth.tar 文件。

我希望能够先加载此模型。到目前为止,我已经能够弄清楚我必须使用以下内容:

model = torch.load('iNat_2018_InceptionV3.pth.tar', map_location='cpu')

这似乎可行,因为print(model) 打印出大量数字和其他值,我认为这些值是权重和偏差的值。

在此之后,我需要能够用它对图像进行分类。我一直无法弄清楚这一点。我必须如何格式化图像?图像是否应该转换为数组?在此之后,我必须如何将输入数据传递给网络?

【问题讨论】:

How to load and use a pretained PyTorch InceptionV3 model to classify an image的可能重复 【参考方案1】:

你基本上需要做和 tensorflow 一样的事情。也就是说,当您存储网络时,只会存储参数(即网络中的可训练对象),而不是“胶水”,这就是使用训练模型所需的所有逻辑。 因此,如果您有一个.pth.tar 文件,您可以加载它,从而覆盖已定义模型的参数值。

这意味着保存/加载模型的一般过程如下:

编写您的网络定义(即您的nn.Module 对象) 以您想要的方式训练或更改网络参数 使用torch.save保存参数 当您想使用该网络时,请使用 nn.Module 对象的相同定义来首先实例化一个 pytorch 网络 然后 使用torch.load 覆盖网络参数的值

这里有一些关于如何做到这一点的参考讨论:pytorch forums

这是一个超短的mwe:

# to store
torch.save(
    'state_dict': model.state_dict(),
    'optimizer' : optimizer.state_dict(),
, 'filename.pth.tar')

# to load
checkpoint = torch.load('filename.pth.tar')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])

【讨论】:

@tamtam 请不要劫持不同问题的答案。上面的答案确实回答了所提出的问题,因为它是关于加载参数的一般流程。不过,我已经更新了答案,也给出了一个 mwe。 OP 从未批准您的回答,而且您还是遵守了我的问题……所以这不是劫持。将其视为澄清请求。感谢mwe,它澄清了事情。

以上是关于如何加载和使用 PyTorch (.pth.tar) 模型的主要内容,如果未能解决你的问题,请参考以下文章

Pytorch自定义加载预训练权重

PyTorch参数模型转换为PT模型

如何加载数据以及如何使用 pytorch 进行数据扩充

如何将pickle文件中的数据集加载到PyTorch中?

如何使用pytorch同时迭代两个数据加载器?

pytorch - 如何从 DistributedDataParallel 学习中保存和加载模型