如何在 PyTorch 中保存模型架构?

Posted

技术标签:

【中文标题】如何在 PyTorch 中保存模型架构?【英文标题】:How to save model architecture in PyTorch? 【发布时间】:2020-01-05 00:33:29 【问题描述】:

我知道我可以通过torch.save(model.state_dict(), FILE)torch.save(model, FILE) 保存模型。但是它们都没有保存模型的架构。

那么我们如何在 PyTorch 中保存模型的架构,例如在 Tensorflow 中创建 .pb 文件?我想对我的模型应用不同的调整。如果我无法保存模型的架构,我是否有比每次都复制整个类定义并创建一个新类更好的方法?

【问题讨论】:

架构是什么意思? torch.save(model, FILE) 应该可以正常工作 我不认为你可以那么容易地保存模型架构。我只知道保存nn.Sequential 模型的字符串。你最终找到了什么解决方案?顺便说一句,人们将架构称为计算的实际 DAG,而模型通常意味着具有给定 DAG 的 NN + 参数。 我最终将超参数保存在 .sh 文件中。所以我可以在不复制代码的情况下训练很多模型。听起来有点愚蠢,但对我来说似乎是最简单的方法。 @查理帕克 【参考方案1】:

你可以参考this文章了解如何保存分类器。要对模型进行调整,您可以创建一个新模型,它是现有模型的子模型。


class newModel( oldModelClass):
    def __init__(self):
        super(newModel, self).__init__()

通过此设置,newModel 具有所有层以及oldModelClass 的转发功能。如果需要调整,可以在__init__函数中定义新的层,然后编写一个新的forward函数来定义它。

【讨论】:

【参考方案2】:

保存所有参数 (state_dict) 和所有 Modules 是不够的,因为有操作张量的操作,但仅反映在具体实现的实际代码中(例如, reshapeing in ResNet)。

此外,网络可能没有固定且预先确定的计算图:您可以想象一个具有分支或循环(循环)的网络。

因此,您必须保存实际代码。

或者,如果网络中没有分支/循环,您可以保存计算图,例如,参见this post。

您还应该考虑使用onnx 导出您的模型,并拥有一个既能捕获训练的权重又能捕获计算图的表示。

【讨论】:

【参考方案3】:

关于实际问题:

那么我们如何在 PyTorch 中保存模型的架构,就像在 Tensorflow 中创建一个 .pb 文件一样?

答案是:你不能

有没有办法在不声明类定义的情况下加载经过训练的模型? 我希望加载模型架构和参数。

不,你必须先加载类定义,这是 python 酸洗限制。

https://discuss.pytorch.org/t/how-to-save-load-torch-models/718/11

不过,这篇 PyTorch 帖子中还列出了其他选项(可能您已经看过其中的大部分):

https://pytorch.org/tutorials/beginner/saving_loading_models.html

【讨论】:

【参考方案4】:

PyTorch 序列化模型进行推理的方式是使用 torch.jit 将模型编译为 TorchScript。

PyTorch 的 TorchScript 支持比 TensorFlow 更高级的控制流,因此可以通过跟踪 (torch.jit.trace) 或编译 Python 模型代码 (torch.jit.script) 进行序列化。

很好的参考资料:

解释这一点的视频:https://www.youtube.com/watch?app=desktop&v=2awmrMRf0dA

文档:https://pytorch.org/docs/stable/jit.html

【讨论】:

以上是关于如何在 PyTorch 中保存模型架构?的主要内容,如果未能解决你的问题,请参考以下文章

Pytorch如何保存训练好的模型

在无法访问模型类代码的情况下保存 PyTorch 模型

我如何知道 Pytorch 中预训练模型的架构?

如果我希望 OpenCV dnn 模块可以加载 PyTorch 的模型,我应该如何保存它

pytorch如何给预训练模型添加新的层

如何加载和使用 PyTorch (.pth.tar) 模型