如何使用 torch.hub.load 加载本地模型?

Posted

技术标签:

【中文标题】如何使用 torch.hub.load 加载本地模型?【英文标题】:How do I load a local model with torch.hub.load? 【发布时间】:2021-07-21 22:50:41 【问题描述】:

我需要避免从网上下载模型(由于安装机器的限制)。

这可行,但从网上下载模型

model = torch.hub.load('pytorch/vision:v0.9.0', 'deeplabv3_resnet101', pretrained=True)

我已将.pth 文件和hubconf.py 文件放在/tmp/ 文件夹中,并将我的代码更改为

model = torch.hub.load('/tmp/', 'deeplabv3_resnet101', pretrained=True, source='local')

但令我惊讶的是,它仍然从互联网上下载模型。我究竟做错了什么?如何在本地加载模型?

为了给您提供更多细节,我在运行时具有只读卷的 Docker 容器中执行所有这些操作,这就是下载新文件失败的原因。

【问题讨论】:

在一些早期版本的 PyTorch 中似乎没有本地加载选项。你用的是哪个版本? 收集torch==1.8.1 下载torch-1.8.1-cp38-cp38-manylinux1_x86_64.whl (804.1 MB) 收集torchsummary==1.5.1 下载torchsummary-1.5.1-py3-none -any.whl (2.8 kB) 收集torchvision==0.9.1 下载torchvision-0.9.1-cp38-cp38-manylinux1_x86_64.whl (17.4 MB) pretrained=True,s ource 附近的代码在语法上似乎不正确。是原版的吗? 【参考方案1】:

您可以采用两种方法在没有 Internet 连接的情况下在机器上获取可交付模型。

    在普通机器上加载带有预训练模型的 DeepLab,使用JIT 编译器将其导出为图形,然后放入机器中。脚本很容易理解:

     # To export
     model = torch.hub.load('pytorch/vision:v0.9.0', 'deeplabv3_resnet101', pretrained=True).eval()
     traced_graph = torch.jit.trace(model, torch.randn(1, 3, H, W))
     traced_graph.save('DeepLab.pth')
    
     # To load
     model = torch.jit.load('DeepLab.pth').eval().to(device)
    

    在这种情况下,权重和网络结构被保存为计算图,因此您不需要任何额外的文件。

    看看torchvision's GitHub repository。

    对于具有 Resnet101 主干权重的 DeepLabV3,有一个 download URL。

    您可以下载这些权重一次,然后使用带有 pretrained=False 标志的 torchvision 的 deeplab 并手动加载权重。

     model = torch.hub.load('pytorch/vision:v0.9.0', 'deeplabv3_resnet101', pretrained=False)
     model.load_state_dict(torch.load('downloaded weights path'))
    

    请注意,在 state dict 中可能有一个 ['state_dict'] 或一些类似的父键,您可以在其中使用:

     model.load_state_dict(torch.load('downloaded weights path')['state_dict'])
    

【讨论】:

什么是 H 和 W? 我认为它是 deeplab 规范指定的最小 224。当我尝试这个跟踪错误:module._c._create_method_from_trace( RuntimeError: Encountering a dict at the output of the tracer might cause the trace to be incorrect, this is only valid if the container structure does not change based on the module's inputs. Consider using a constant container instead (e.g. for list, use a tuple` 代替。对于dict,请改用NamedTuple)。如果您绝对需要它并且知道副作用,请将 strict=False 传递给 trace() 以允许这种行为。` 用 strict=False 试过了,到目前为止似乎有效。我将启动 docker 环境,看看是否可行 H 和 W 是图像的高度和宽度(张量),你是对的 - DeepLabV3 要求的最小尺寸是 224x224。 load_state_dict(chkpt, strict=False) 是仅加载所需权重的正确方法。最后的建议是 - 通过可视化测试模型输出,不要立即部署。 由于 DeepLab 的输出格式,jit 失败:OrderedDict -> ['out', 'aux']。但是,通过将 strict=False 传递给 JIT 跟踪器,它将编译为仅输出 ['out'] 且大小为 [1, 21, H, W] 的图。

以上是关于如何使用 torch.hub.load 加载本地模型?的主要内容,如果未能解决你的问题,请参考以下文章

基于flask和网页端部署yolo自训练模型

基于flask和网页端部署yolo自训练模型

基于flask和网页端部署yolo自训练模型

Yolov5自定义对象检测模型未加载

NLP Transformers:获得固定句子嵌入向量形状的最佳方法?

如何从 python 中的预训练模型中获取权重并在 tensorflow 中使用它?