如何在 PyTorch 中保存模型架构?
Posted
技术标签:
【中文标题】如何在 PyTorch 中保存模型架构?【英文标题】:How to save model architecture in PyTorch? 【发布时间】:2020-01-05 00:33:29 【问题描述】:我知道我可以通过torch.save(model.state_dict(), FILE)
或torch.save(model, FILE)
保存模型。但是它们都没有保存模型的架构。
那么我们如何在 PyTorch 中保存模型的架构,例如在 Tensorflow 中创建 .pb
文件?我想对我的模型应用不同的调整。如果我无法保存模型的架构,我是否有比每次都复制整个类定义并创建一个新类更好的方法?
【问题讨论】:
架构是什么意思?torch.save(model, FILE)
应该可以正常工作
我不认为你可以那么容易地保存模型架构。我只知道保存nn.Sequential
模型的字符串。你最终找到了什么解决方案?顺便说一句,人们将架构称为计算的实际 DAG,而模型通常意味着具有给定 DAG 的 NN + 参数。
我最终将超参数保存在 .sh
文件中。所以我可以在不复制代码的情况下训练很多模型。听起来有点愚蠢,但对我来说似乎是最简单的方法。 @查理帕克
【参考方案1】:
你可以参考this文章了解如何保存分类器。要对模型进行调整,您可以创建一个新模型,它是现有模型的子模型。
class newModel( oldModelClass):
def __init__(self):
super(newModel, self).__init__()
通过此设置,newModel 具有所有层以及oldModelClass
的转发功能。如果需要调整,可以在__init__
函数中定义新的层,然后编写一个新的forward函数来定义它。
【讨论】:
【参考方案2】:保存所有参数 (state_dict
) 和所有 Modules 是不够的,因为有操作张量的操作,但仅反映在具体实现的实际代码中(例如, reshape
ing in ResNet)。
此外,网络可能没有固定且预先确定的计算图:您可以想象一个具有分支或循环(循环)的网络。
因此,您必须保存实际代码。
或者,如果网络中没有分支/循环,您可以保存计算图,例如,参见this post。
您还应该考虑使用onnx
导出您的模型,并拥有一个既能捕获训练的权重又能捕获计算图的表示。
【讨论】:
【参考方案3】:关于实际问题:
那么我们如何在 PyTorch 中保存模型的架构,就像在 Tensorflow 中创建一个 .pb 文件一样?
答案是:你不能
有没有办法在不声明类定义的情况下加载经过训练的模型? 我希望加载模型架构和参数。
不,你必须先加载类定义,这是 python 酸洗限制。
https://discuss.pytorch.org/t/how-to-save-load-torch-models/718/11
不过,这篇 PyTorch 帖子中还列出了其他选项(可能您已经看过其中的大部分):
https://pytorch.org/tutorials/beginner/saving_loading_models.html
【讨论】:
【参考方案4】:PyTorch 序列化模型进行推理的方式是使用 torch.jit
将模型编译为 TorchScript。
PyTorch 的 TorchScript 支持比 TensorFlow 更高级的控制流,因此可以通过跟踪 (torch.jit.trace
) 或编译 Python 模型代码 (torch.jit.script
) 进行序列化。
很好的参考资料:
解释这一点的视频:https://www.youtube.com/watch?app=desktop&v=2awmrMRf0dA
文档:https://pytorch.org/docs/stable/jit.html
【讨论】:
以上是关于如何在 PyTorch 中保存模型架构?的主要内容,如果未能解决你的问题,请参考以下文章