Torch 老司机必看 | 在 PyTorch 中加载 Torch 模型

Posted 集智学园

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Torch 老司机必看 | 在 PyTorch 中加载 Torch 模型相关的知识,希望对你有一定的参考价值。

Torch 是 PyTorch 的先行者,也是由 Facebook 维护的易用深度学习框架。然而因为 Torch 使用的语言是比较小众的 Lua,导致 Torch 在深度学习圈(尤其是国内)没有非常的流行起来。因此 Facebook 的相关团队才开发了 PyTorch 这一基于更大众的 Python 语言的深度学习框架。

虽然当时 Torch 框架并没有像今天的 PyTorch 这样流行,但是 Torch 当时还是积攒了一定的用户群,网上也有很多有趣的深度学习项目是用 Torch 编写的。如果你苦于不懂 Lua 语言,又想尝试一下那些基于 Torch 的项目,或者想在 PyTorch 中使用 Torch 训练过的模型,今天我就给大家简单介绍一下将 Torch 模型转移到 PyTorch 中去的方法。


我们可以将 Torch 中的张量、数字、数据表、神经网络模型、字符串等打包成一个“.t7”文件,然后使用 PyTorch 的“load_lua”方法来加载它们。

下面我会演示一个从 Torch 中导出一个张量,并使用 PyTorch 加载这个张量的方法。

首先我会使用 Lua 语言在 Torch 框架中创建一个张量(Tensor) a,并将其导出到文件 a.t7 中。

th> a = torch.randn(10)

                                                                      [0.0027s]

th> torch.save('a.t7', a)

                                                                      [0.0010s]

th> a

-1.4479

 1.3707

 0.5663

-1.0590

 0.0706

-1.6495

-1.0805

 0.8277

-0.4595

 0.1237

[torch.DoubleTensor of size 10]


                                                                      [0.0033s]

然后我们到 PyTorch 中(注意此时使用的语言是 Python),使用 load_lua 方法加载上面导出的 a.t7

In [1]: import torch

In [2]: from torch.utils.serialization import load_lua

In [3]: a = load_lua('a.t7')

# 可以观察到在 Torch 中建立的 Tensor a 已经原封不动的转移到 PyTorch 中了。

In [4]: a

Out[4]:

-1.4479

 1.3707

 0.5663

-1.0590

 0.0706

-1.6495

-1.0805

 0.8277

-0.4595

 0.1237

[torch.DoubleTensor of size 10]

下面我们再演示一个稍微复杂点的例子。我们导出&加载一个2层的序列神经网络模型。

首先我们在 Torch 中建立这个序列模型 a,它有一个全连接层,其后紧跟一个 ReLU 进行非线性输出。建立后我们同样将它打包导出为 a.t7

th> a = nn.Sequential():add(nn.Linear(10, 20)):add(nn.ReLU())

                                                                      [0.0001s]

th> a

nn.Sequential {

  [input -> (1) -> (2) -> output]

  (1): nn.Linear(10 -> 20)

  (2): nn.ReLU

}

                                                                      [0.0001s]

th> torch.save('a.t7', a)

                                                                      [0.0008s]

th>

同样的,我们只需要在 PyTorch 中加载 a.t7,就可以加载刚刚在 Torch 中建立的模型啦,是不是超简单!

In [5]: a = load_lua('a.t7')


In [6]: a

Out[6]:

nn.Sequential {

  [input -> (0) -> (1) -> output]

  (0): nn.Linear(10 -> 20)

  (1): nn.ReLU

}


In [7]: a.__class__

Out[7]: torch.legacy.nn.Sequential.Sequential

好!这次的 PyTorch 小教程到这里就结束啦~

我们以后会给大家推送更多实用的 PyTorch 使用技巧的!

See You~~Torch 老司机必看 | 在 PyTorch 中加载 Torch 模型




推荐阅读:

女生看了会流泪 | 训练一个会“卸妆”的深度学习模型

“火炬上的深度学习”之缘起


为什么他们要来集智AI学园学习 PyTorch?




获取更多更有趣的AI教程吧!

学园网站:campus.swarma.org

 商务合作|zhangqian@swarma.org     

投稿转载|wangjiannan@swarma.org

点击学习PyTorch

以上是关于Torch 老司机必看 | 在 PyTorch 中加载 Torch 模型的主要内容,如果未能解决你的问题,请参考以下文章

几个精髓磁力搜索引擎,各位“绅士”请低调使用~

必看DBA新手常见问题解答

pytorch torch类

PyTorch中通过torch.save保存模型和torch.load加载模型介绍

Pytorch 中 torch.cat() 函数解析

torch.Tensor 和 torch.tensor