如何将 Pytorch 模型导入 MATLAB
Posted
技术标签:
【中文标题】如何将 Pytorch 模型导入 MATLAB【英文标题】:How to import Pytorch model into MATLAB 【发布时间】:2018-12-05 05:31:25 【问题描述】:我在 Pytorch 中创建了一个模型,我希望将其转移到 MATLAB,这里显示了一个最小的示例
import torch.nn as nn
import torch
class cnn(nn.Module):
def __init__(self):
super(cnn, self).__init__()
self.fc1 = nn.Sequential(
nn.Linear(10, 1),
nn.ReLU(True)
)
def forward(self, x):
out = self.fc1(x)
return out
the_net = cnn()
torch.save(the_net,'desperation.h5')
然后我在 MATLAB 中调用
net = importKerasLayers('desperation.h5')
这给出了错误信息
Error using importKerasLayers (line 104)
Unable to read HDF5 file 'desperation.h5'. The error message was: 'The filename specified was either
not found on the MATLAB path or it contains unsupported characters.''
文件在路径上,我可以将模型加载回 Python。我真正想要的是任何允许我将模型从 Pytorch 转移到 MATLAB 而无需手动复制所有权重的解决方案。
我正在运行 MATLAB 2018b、Python 3.6 和 Pytorch 0.4.0
【问题讨论】:
importKerasLayers 是否具有读取 pytorch 权重的功能? 我不确定,最好我可以告诉 torch.save 只是以 pytorch 格式保存,无论你放在什么结尾,这就是为什么这不起作用。无论如何,权重只是浮点数,所以主要的是包含权重的结构。这就是为什么我想先把它变成keras格式,这样matlab才能理解结构。 你看过onnx吗? 我有,但我不知道如何将 onnx 模型加载到 MATLAB 中。 【参考方案1】:我过去一直在使用这个工具并取得了一些成功:https://github.com/albanie/mcnPyTorch 从 Pytorch 转到 MatConvNet。
【讨论】:
以上是关于如何将 Pytorch 模型导入 MATLAB的主要内容,如果未能解决你的问题,请参考以下文章