在无法访问模型类代码的情况下保存 PyTorch 模型
Posted
技术标签:
【中文标题】在无法访问模型类代码的情况下保存 PyTorch 模型【英文标题】:Saving PyTorch model with no access to model class code 【发布时间】:2020-04-04 20:23:41 【问题描述】:如何保存 PyTorch 模型而不需要在某处定义模型类?
免责声明:
在Best way to save a trained model in PyTorch? 中,没有解决方案(或有效的解决方案)可以在不访问模型类代码的情况下保存模型。
【问题讨论】:
【参考方案1】:如果您打算使用可用的 Pytorch 库(即 Python、C++ 或它支持的其他平台中的 Pytorch)进行推理,那么最好的方法是通过 TorchScript。
我认为最简单的方法是使用trace = torch.jit.trace(model, typical_input)
,然后使用torch.jit.save(trace, path)
。然后您可以使用torch.jit.load(path)
加载跟踪模型。
这是一个非常简单的例子。我们制作两个文件:
train.py
:
import torch
class Model(torch.nn.Module):
def __init__(self):
super(Model, self).__init__()
self.linear = torch.nn.Linear(4, 4)
def forward(self, x):
x = torch.relu(self.linear(x))
return x
model = Model()
x = torch.FloatTensor([[0.2, 0.3, 0.2, 0.7], [0.4, 0.2, 0.8, 0.9]])
with torch.no_grad():
print(model(x))
traced_cell = torch.jit.trace(model, (x))
torch.jit.save(traced_cell, "model.pth")
infer.py
:
import torch
x = torch.FloatTensor([[0.2, 0.3, 0.2, 0.7], [0.4, 0.2, 0.8, 0.9]])
loaded_trace = torch.jit.load("model.pth")
with torch.no_grad():
print(loaded_trace(x))
依次运行这些会给出结果:
python train.py
tensor([[0.0000, 0.1845, 0.2910, 0.2497],
[0.0000, 0.5272, 0.3481, 0.1743]])
python infer.py
tensor([[0.0000, 0.1845, 0.2910, 0.2497],
[0.0000, 0.5272, 0.3481, 0.1743]])
结果是一样的,所以我们很好。 (注意,由于nn.Linear层初始化的随机性,这里每次的结果都会不同)。
TorchScript 提供了更复杂的架构和图形定义(包括 if 语句、while 循环等),可以保存在单个文件中,而无需在推理时重新定义图形。有关更高级的可能性,请参阅文档(上面链接)。
【讨论】:
使用torch脚本有什么缺点? 好吧,主要问题是您仍然需要某种 pytorch 环境。另外,如果您想继续训练,我想这将非常困难/不可能。有时它也可能有点错误/难以调试。但这基本上是 pytorch 对在 tensorflow 中轻松保存整个图形的回答。每个版本都在改进,并且已经非常好 imo。【参考方案2】:我建议您将 pytorch 模型转换为 onnx 并保存。可能是在不访问类的情况下存储模型的最佳方式。
【讨论】:
我对 onnx 不感兴趣,因为它有一些限制......并且不支持某些 pytorch 功能......【参考方案3】:没有解决方案(或工作解决方案)可以在不访问类的情况下保存模型。
你可以保存任何你喜欢的东西。
您可以保存模型,torch.save(model, filepath)
。它保存模型对象本身。
您可以只保存模型状态字典。
torch.save(model.state_dict(), filepath)
此外,你可以保存任何你喜欢的东西,因为torch.save
只是一个基于泡菜的保存。
state =
'hello_text': 'just the optimizer sd will be saved',
'optimizer': optimizer.state_dict(),
torch.save(state, filepath)
您可以在前一段时间查看what I wrote on torch.save
。
【讨论】:
我希望存在一些解决方法,因为 tensorflow 中有这样的选项。【参考方案4】:由一位核心 PyTorch 开发人员 (smth) 提供官方答案:
在没有代码的情况下加载 pytorch 模型存在限制。
第一个限制: 我们只保存类定义的源代码。除此之外,我们不会保存(例如该类所引用的包源)。
例如:
import foo
class MyModel(...):
def forward(input):
foo.bar(input)
这里的包foo
没有保存在模型检查点中。
第二个限制: 健壮地序列化 python 结构是有限制的。例如,默认的pickler 不能序列化 lambda。有一些帮助程序包可以序列化比标准更多的 python 结构,但它们仍然有局限性。 Dill 25 就是这样一个包。
鉴于这些限制,没有原始源文件就没有可靠的方法让 torch.load 工作。
【讨论】:
以上是关于在无法访问模型类代码的情况下保存 PyTorch 模型的主要内容,如果未能解决你的问题,请参考以下文章
PyTorch 深度剖析:并行训练的 DP 和 DDP 分别在啥情况下使用及实例
Pytorch学习笔记——Sequential类参数管理与GPU