Unpickling 保存的 pytorch 模型会引发 AttributeError: Can't get attribute 'Net' on <module '__main__' 尽管内联
Posted
技术标签:
【中文标题】Unpickling 保存的 pytorch 模型会引发 AttributeError: Can\'t get attribute \'Net\' on <module \'__main__\' 尽管内联添加了类定义【英文标题】:Unpickling saved pytorch model throws AttributeError: Can't get attribute 'Net' on <module '__main__' despite adding class definition inlineUnpickling 保存的 pytorch 模型会引发 AttributeError: Can't get attribute 'Net' on <module '__main__' 尽管内联添加了类定义 【发布时间】:2019-08-24 14:34:22 【问题描述】:我正在尝试在烧瓶应用中提供 pytorch 模型。当我早些时候在 jupyter notebook 上运行它时,这段代码正在工作,但现在我在虚拟环境中运行它,显然即使类定义就在那里,它也无法获得属性“Net”。所有其他类似的问题都告诉我在同一个脚本中添加已保存模型的类定义。但它仍然不起作用。火炬版本是 1.0.1(保存的模型和 virtualenv 一样被训练) 我究竟做错了什么? 这是我的代码。
import os
import numpy as np
from flask import Flask, request, jsonify
import requests
import torch
from torch import nn
from torch.nn import functional as F
MODEL_URL = 'https://storage.googleapis.com/judy-pytorch-model/classifier.pt'
r = requests.get(MODEL_URL)
file = open("model.pth", "wb")
file.write(r.content)
file.close()
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(input_size, hidden_size)
self.fc2 = nn.Linear(hidden_size, hidden_size)
self.fc3 = nn.Linear(hidden_size, output_size)
def forward(self, x):
x = torch.sigmoid(self.fc1(x))
x = torch.sigmoid(self.fc2(x))
x = self.fc3(x)
return F.log_softmax(x, dim=-1)
model = torch.load('model.pth')
app = Flask(__name__)
@app.route("/")
def hello():
return "Binary classification example\n"
@app.route('/predict', methods=['GET'])
def predict():
x_data = request.args['x_data']
x_data = x_data.split()
x_data = list(map(float, x_data))
sample = np.array(x_data)
sample_tensor = torch.from_numpy(sample).float()
out = model(sample_tensor)
_, predicted = torch.max(out.data, -1)
if predicted.item() == 0:
pred_class = "Has no liver damage - ", predicted.item()
elif predicted.item() == 1:
pred_class = "Has liver damage - ", predicted.item()
return jsonify(pred_class)
这是完整的回溯:
Traceback (most recent call last):
File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/bin/flask", line 10, in <module>
sys.exit(main())
File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/flask/cli.py", line 894, in main
cli.main(args=args, prog_name=name)
File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/flask/cli.py", line 557, in main
return super(FlaskGroup, self).main(*args, **kwargs)
File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/click/core.py", line 717, in main
rv = self.invoke(ctx)
File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/click/core.py", line 1137, in invoke
return _process_result(sub_ctx.command.invoke(sub_ctx))
File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/click/core.py", line 956, in invoke
return ctx.invoke(self.callback, **ctx.params)
File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/click/core.py", line 555, in invoke
return callback(*args, **kwargs)
File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/click/decorators.py", line 64, in new_func
return ctx.invoke(f, obj, *args, **kwargs)
File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/click/core.py", line 555, in invoke
return callback(*args, **kwargs)
File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/flask/cli.py", line 767, in run_command
app = DispatchingApp(info.load_app, use_eager_loading=eager_loading)
File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/flask/cli.py", line 293, in __init__
self._load_unlocked()
File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/flask/cli.py", line 317, in _load_unlocked
self._app = rv = self.loader()
File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/flask/cli.py", line 372, in load_app
app = locate_app(self, import_name, name)
File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/flask/cli.py", line 235, in locate_app
__import__(module_name)
File "/Users/judyraj/Judy/pytorch-deployment/flask_app/app.py", line 34, in <module>
model = torch.load('model.pth')
File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/torch/serialization.py", line 368, in load
return _load(f, map_location, pickle_module)
File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/torch/serialization.py", line 542, in _load
result = unpickler.load()
AttributeError: Can't get attribute 'Net' on <module '__main__' from '/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/bin/flask'>
This 不能解决我的问题。我不想改变我坚持模型的方式。 torch.save() 在虚拟环境之外对我来说工作得很好。我不介意将类定义添加到脚本中。尽管如此,我还是想看看是什么导致了错误。
【问题讨论】:
这与那个无关。 torch.save() 在 virtualenv 之外对我来说工作正常。我只是想弄清楚如何修复错误。我不想改变对持久性建模的方式。 你是怎么save
模特的?你是保存整个模型还是只保存它的state_dict
?
整个模型。不是 state_dict。我可以加载它并在本地成功使用它。我不能在 virtualenv 中做到这一点。我正在尝试将其部署到 AWS Lambda
如果我只使用 state_dict 就可以了。尽管添加了类定义,但我试图理解为什么 pickle 会引发属性错误。
您的应用运行情况如何?您可以在代码中添加print(__name__)
行吗?我猜你的脚本的__name__
在保存泡菜时等于__main__
,但现在不同了,当你用烧瓶运行它时,会导致属性查找错误。
【参考方案1】:
这可能不是一个非常受欢迎的答案,但是,我发现dill
包在使我的代码工作方面非常一致。对我来说,我什至没有尝试加载模型,我正在尝试解压缩一个自定义对象来帮助我的东西,但由于某种原因它找不到它。我不知道为什么,但根据我的经验,莳萝似乎是腌制的更好选择:
# - path to files
path = Path(path2dataset).expanduser()
path2file_data_prep = Path(path2file_data_prep).expanduser()
# - create dag dataprep obj
print(f'path to data set path=')
dag_prep = SplitDagDataPreparation(path)
# - save data prep splits object
print(f'saving to path2file_data_prep=')
torch.save('data_prep': dag_prep, path2file_data_prep, pickle_module=dill)
# - load the data prep splits object to test it loads correctly
db = torch.load(path2file_data_prep, pickle_module=dill)
db['data_prep']
print(db)
return path2file_data_prep
【讨论】:
【参考方案2】:首先我初始化了一个空模型,然后加载了保存的模型,这出于某种原因解决了这个问题。
【讨论】:
我遇到了同样的问题。导入类定义为我解决了这个问题。【参考方案3】:(这是部分答案)
我认为 torch.save(model,'model.pt')
不能在命令提示符下工作,或者当模型从一个以 '__main__'
运行的脚本保存并从另一个脚本加载时。
原因是torch必须自动加载用于保存文件的模块,并从__name__
获取模块名称。
现在是部分部分:目前尚不清楚如何解决此问题,尤其是当您混合使用 virtualenvs 时。
感谢 Jatentaki 在这个方向上开始对话。
【讨论】:
以上是关于Unpickling 保存的 pytorch 模型会引发 AttributeError: Can't get attribute 'Net' on <module '__main__' 尽管内联的主要内容,如果未能解决你的问题,请参考以下文章
ValueError:Numpy中的非字符串名称仅在AWS Lambda上进行unpickling
unpickling 模型文件 python scikit-learn(管道(memory=None, steps=None, verbose=None))