在无法访问模型类代码的情况下保存 PyTorch 模型

Posted

技术标签:

【中文标题】在无法访问模型类代码的情况下保存 PyTorch 模型【英文标题】:Saving PyTorch model with no access to model class code 【发布时间】:2020-04-04 20:23:41 【问题描述】:

如何保存 PyTorch 模型而不需要在某处定义模型类?


免责声明

在Best way to save a trained model in PyTorch? 中,没有解决方案(或有效的解决方案)可以在不访问模型类代码的情况下保存模型。

【问题讨论】:

【参考方案1】:

如果您打算使用可用的 Pytorch 库(即 Python、C++ 或它支持的其他平台中的 Pytorch)进行推理,那么最好的方法是通过 TorchScript。

我认为最简单的方法是使用trace = torch.jit.trace(model, typical_input),然后使用torch.jit.save(trace, path)。然后您可以使用torch.jit.load(path) 加载跟踪模型。

这是一个非常简单的例子。我们制作两个文件:

train.py

import torch

class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        x = torch.relu(self.linear(x))
        return x

model = Model()
x = torch.FloatTensor([[0.2, 0.3, 0.2, 0.7], [0.4, 0.2, 0.8, 0.9]])
with torch.no_grad():
    print(model(x))
    traced_cell = torch.jit.trace(model, (x))
torch.jit.save(traced_cell, "model.pth")

infer.py

import torch
x = torch.FloatTensor([[0.2, 0.3, 0.2, 0.7], [0.4, 0.2, 0.8, 0.9]])
loaded_trace = torch.jit.load("model.pth")
with torch.no_grad():
    print(loaded_trace(x))

依次运行这些会给出结果:

python train.py
tensor([[0.0000, 0.1845, 0.2910, 0.2497],
        [0.0000, 0.5272, 0.3481, 0.1743]])

python infer.py
tensor([[0.0000, 0.1845, 0.2910, 0.2497],
        [0.0000, 0.5272, 0.3481, 0.1743]])

结果是一样的,所以我们很好。 (注意,由于nn.Linear层初始化的随机性,这里每次的结果都会不同)。

TorchScript 提供了更复杂的架构和图形定义(包括 if 语句、while 循环等),可以保存在单个文件中,而无需在推理时重新定义图形。有关更高级的可能性,请参阅文档(上面链接)。

【讨论】:

使用torch脚本有什么缺点? 好吧,主要问题是您仍然需要某种 pytorch 环境。另外,如果您想继续训练,我想这将非常困难/不可能。有时它也可能有点错误/难以调试。但这基本上是 pytorch 对在 tensorflow 中轻松保存整个图形的回答。每个版本都在改进,并且已经非常好 imo。【参考方案2】:

我建议您将 pytorch 模型转换为 onnx 并保存。可能是在不访问类的情况下存储模型的最佳方式。

【讨论】:

我对 onnx 不感兴趣,因为它有一些限制......并且不支持某些 pytorch 功能......【参考方案3】:

没有解决方案(或工作解决方案)可以在不访问类的情况下保存模型。

你可以保存任何你喜欢的东西。

您可以保存模型,torch.save(model, filepath)。它保存模型对象本身。

您可以只保存模型状态字典。

torch.save(model.state_dict(), filepath)

此外,你可以保存任何你喜欢的东西,因为torch.save 只是一个基于泡菜的保存。

state = 
    'hello_text': 'just the optimizer sd will be saved',
    'optimizer': optimizer.state_dict(),


torch.save(state, filepath)

您可以在前一段时间查看what I wrote on torch.save

【讨论】:

我希望存在一些解决方法,因为 tensorflow 中有这样的选项。【参考方案4】:

由一位核心 PyTorch 开发人员 (smth) 提供官方答案:

在没有代码的情况下加载 pytorch 模型存在限制。

第一个限制: 我们只保存类定义的源代码。除此之外,我们不会保存(例如该类所引用的包源)。

例如:

import foo

class MyModel(...):
    def forward(input):
        foo.bar(input)

这里的包foo没有保存在模型检查点中。

第二个限制: 健壮地序列化 python 结构是有限制的。例如,默认的pickler 不能序列化 lambda。有一些帮助程序包可以序列化比标准更多的 python 结构,但它们仍然有局限性。 Dill 25 就是这样一个包。

鉴于这些限制,没有原始源文件就没有可靠的方法让 torch.load 工作。

【讨论】:

以上是关于在无法访问模型类代码的情况下保存 PyTorch 模型的主要内容,如果未能解决你的问题,请参考以下文章

PyTorch 模型保存错误:“无法腌制本地对象”

PyTorch 深度剖析:并行训练的 DP 和 DDP 分别在啥情况下使用及实例

Pytorch模型保存与加载,并在加载的模型基础上继续训练

Pytorch学习笔记——Sequential类参数管理与GPU

Pytorch文本分类(imdb数据集),含DataLoader数据加载,最优模型保存

如何在 PyTorch 中保存模型架构?