python如何导入pth模型

Posted

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了python如何导入pth模型相关的知识,希望对你有一定的参考价值。

在Python中导入pth模型有两种方式:

(1)使用torch.load():这是最常见的方法,可以从pth文件中加载模型,代码如下:

model = torch.load('model.pth')

(2)使用torch.nn.Module.load_state_dict():如果只是保存模型的参数(state_dict),而不是完整的模型,就可以使用这种方法来加载,代码如下:

model.load_state_dict(state_dict)

注意:torch.load()会返回一个包含类和参数的字典,而使用torch.nn.Module.load_state_dict()只需要返回参数。
参考技术A 在Python中导入PyTorch模型非常简单,您只需要使用PyTorch的 torch.load()函数即可加载模型。例如,下面的代码段将加载一个名为mymodel.pth的模型:

model = torch.load('mymodel.pth')

从名称中带有点的文件夹导入 - Python

【中文标题】从名称中带有点的文件夹导入 - Python【英文标题】:import from a folder with dots in name - Python 【发布时间】:2017-01-25 06:29:42 【问题描述】:

我有一个 python 包,如下所示:

package/
├── __init__.py
├── PyMySQL-0.7.6-py2.7.egg
├── pymysql
├── PyMySQL-0.7.x.pth  
└── tests.py

文件夹结构无法更改,因为它来自第三方库。

.pth文件的内容是

import sys; sys.__plen = len(sys.path)
./PyMySQL-0.7.6-py2.7.egg
import sys; new=sys.path[sys.__plen:]; del sys.path[sys.__plen:]; p=getattr(sys,'__egginsert',0); sys.path[p:p]=new; sys.__egginsert = p+len(new)

tests.py

中包含 pymysql 的最佳方式是什么

我显然不能使用from PyMySQL-0.7.6-py2.7.egg,因为文件夹名称包含点。

附:绝对路径未知,因为此代码应该部署到 AWS lambda

【问题讨论】:

我会尝试使用importlib.import_module 函数。但这不起作用,您需要复制 importlib(将其命名为 myimportlib)包并更改 _bootstrap._find_and_load_unlocked 的路径拆分算法。听起来很酷的项目,但我不想维护代码!我粘贴了源代码的 url,请注意这是 Python 提示,需要 python 3.6 才能运行。 github.com/python/cpython/tree/master/Lib/importlib 是的,绝对不是我想要的:D 但如果您无法避免问题,显然这是唯一的解决方案。就我个人而言,我相信这样做会很酷,并且可能对其他人有用。如果您使用实现,请随时通知我,我想看看如何测试这样的功能。 【参考方案1】:

问题是模块的 .pth 文件没有被使用。 Amazon 的 Lambda 构建说明让您将依赖项安装到与 Lambda 函数文件相同的根目录中。此目录未配置为“站点目录”(例如,虚拟环境中的 lib/pythonX.Y/site-packages/ 或用于全局安装的 /usr/lib/pythonX.Y/site-packages (CentOS) 或 /usr/local/lib/pythonX.Y/dist-packages (Ubuntu)),因此不评估 .pth 文件。

您可以将当前工作目录配置为站点目录。 The site module documentation 将创建 usercustomize.py 描述为添加自定义站点目录的约定,但 Lambda 提出了另一个问题:要自动评估,该文件应该在 ~/.local/lib/pythonX.Y/site-packages 中找到。我在我的 Lambda .zip 文件的根目录中创建了一个usercustomize.py,并首先将它导入到我的 Lambda 函数处理程序文件中。然后我可以导入使用 .pth 文件来配置点符号的模块。

usercustomize.py

import site
import os

site.addsitedir(os.getcwd())

lambda_handler.py

import usercustomize
from foo.bar import baz

使用此方法,您可以将依赖项安装到分发 zip 文件根目录中的 site-packages 目录中(os.getcwd() 变为 os.path.join(os.getcwd(), 'site-packages')),如果这对您很重要,则可以使结构更清晰。

【讨论】:

以上是关于python如何导入pth模型的主要内容,如果未能解决你的问题,请参考以下文章

如何加载和使用 PyTorch (.pth.tar) 模型

如何将自定义 Pytorch 模型转换为 torchscript(pth 到 pt 模型)?

如何使用 Detectron2 .pth 模型从存储中进行预测..?

python基础——python添加模块搜索路径和包的导入

Pytorch的pth模型转onnx,再用ONNX Runtime调用推理(附python代码)

从名称中带有点的文件夹导入 - Python