Python中的模块导入无法加载Pytorch nn.Module的派生类

Posted

技术标签:

【中文标题】Python中的模块导入无法加载Pytorch nn.Module的派生类【英文标题】:Derived Class of Pytorch nn.Module Cannot be Loaded by Module Import in Python 【发布时间】:2020-08-18 19:28:41 【问题描述】:

将 Python 3.6 与 Pytorch 1.3.1 结合使用。我注意到当整个模块被导入另一个模块时,一些保存的 nn.Modules 无法加载。举个例子,这里是一个最小工作示例的模板。

#!/usr/bin/env python3
#encoding:utf-8
# file 'dnn_predict.py'

from torch import nn
class NN(nn.Module):##NN network
    # Initialisation and other class methods

networks=[torch.load(f=os.path.join(resource_directory, 'nn-classify-cpu_fold.pkl'.format(fold=fold))) for fold in range(5)]
...
if __name__=='__main__':
    # Some testing snippets
    pass

当我直接在 shell 中运行它时,整个文件工作得很好。但是,当我想使用该类并使用此代码将神经网络加载到另一个文件中时,它会失败。

#!/usr/bin/env python3
#encoding:utf-8
from dnn_predict import *

错误为AttributeError: Can't get attribute 'NN' on <module '__main__'>

在 Pytorch 中加载保存的变量或导入模块是否与其他常见的 Python 库不同?一些帮助或指向根本原因的指针将不胜感激。

【问题讨论】:

【参考方案1】:

当您使用torch.save(model, PATH) 保存模型时,整个对象将使用pickle 进行序列化,这不会保存类本身,而是包含该类的文件的路径,因此在加载模型时完全相同的目录和需要文件结构才能找到正确的类。运行 Python 脚本时,该文件的模块为 __main__,因此如果要加载该模块,则必须在运行的脚本中定义 NN 类。

这样很不灵活,所以推荐的做法是不保存整个模型,而是只保存状态字典,只保存模型的参数。

# Save the state dictionary of the model
torch.save(model.state_dict(), PATH)

之后,可以加载状态字典并将其应用于您的模型。

from dnn_predict import NN

# Create the model (will have randomly initialised parameters)
model = NN()

# Load the previously saved state dictionary
state_dict = torch.load(PATH)

# Apply the state dictionary to the model
model.load_state_dict(state_dict)

有关状态字典和保存/加载模型的更多详细信息:PyTorch - Saving and Loading Models

【讨论】:

以上是关于Python中的模块导入无法加载Pytorch nn.Module的派生类的主要内容,如果未能解决你的问题,请参考以下文章

权重和偏差扫描无法使用 pytorch 闪电导入模块

python 3.6 on win 10 editdistance模块无法导入(DLL加载失败问题)

出现错误:DLL 加载失败:操作系统无法运行 %1 - Python 2.7;报废模块;导入密码学

我在尝试加载 Pytorch 模型时遇到问题:“无法在模块中找到身份”

Python 模块化 模块搜索顺序重复导入模块加载列表

python c api无法将任何模块导入新创建的模块