pytorch中的add_module函数

Posted Masked Prometheus

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了pytorch中的add_module函数相关的知识,希望对你有一定的参考价值。

现只讲在自定义网络中add_module的作用。

总结:

在自定义网络的时候,由于自定义变量不是Module类型(例如,我们用List封装了几个网络),所以pytorch不会自动注册网络模块add_module函数用来为网络添加模块的,所以我们可以使用这个函数手动添加自定义的网络模块。当然,这种情况,我们也可以使用ModuleList来封装自定义模块,pytorch就会自动注册了。

 

Let\'t start!

add_module函数是在自定义网络添加子模块,例如,当我们自定义一个网络肤过程中,我们既可以

(1)通过self.module=xxx_module的方式(如下面第3行代码),添加网络模块;

(2)通过add_module函数对网络中添加模块。

(3)通过用nn.Sequential对模块进行封装等等。

 1 class NeuralNetwork(nn.Module):
 2     def __init__(self):
 3         super(NeuralNetwork, self).__init__()
 4         self.layers = nn.Linear(28*28,28*28)
 5 #         self.add_module(\'layers\',nn.Linear(28*28,28*28))  #  跟上面的方式等价
 6         self.linear_relu_stack = nn.Sequential(
 7             nn.Linear(28*28, 512),
 8             nn.ReLU()
 9         )
10 
11     def forward(self, x):
12         for layer in layers:
13             x = layer(x)
14         logits = self.linear_relu_stack(x)
15         return logits

我们实例化类,然后输出网络的模块看一下:

1 0 Linear(in_features=784, out_features=784, bias=True)
2 1 Sequential(
3   (0): Linear(in_features=784, out_features=512, bias=True)
4   (1): ReLU()
5 )

会发现,上面定义的网络子模块都有:Linear和Sequential。

 

但是,有时候pytorch不会自动给我们注册模块,我们需要根据传进来的参数对网络进行初始化,例如:

 

 1 class NeuralNetwork(nn.Module):
 2     def __init__(self, layer_num):
 3         super(NeuralNetwork, self).__init__()
 4         self.layers = [nn.Linear(28*28,28*28) for _ in range(layer_num)]
 5         self.linear_relu_stack = nn.Sequential(
 6             nn.Linear(28*28, 512),
 7             nn.ReLU()
 8         )
 9 
10     def forward(self, x):
11         for layer in layers:
12             x = layer(x)
13         logits = self.linear_relu_stack(x)
14         return logits

对此我们再初始化一个实例,然后看下网络中的模块:

1 model = NeuralNetwork(2)
2 for index,item in enumerate(model.children()):
3     print(index,item)

输出结果就是:

0 Sequential(
  (0): Linear(in_features=784, out_features=512, bias=True)
  (1): ReLU()
) 

 

你会发现定义的Linear模块都不见了,而上面定义的时候,明明都制订了。这是因为pytorch在注册模块的时候,会查看成员的类型,如果成员变量类型是Module的子类,那么pytorch就会注册这个模块,否则就不会。

这里的self.layers是python中的List类型,所以不会自动注册,那么就需要我们再定义后,手动注册(下图黄色标注部分):

 1 class NeuralNetwork(nn.Module):
 2     def __init__(self, layer_num):
 3         super(NeuralNetwork, self).__init__()
 4         self.layers = [nn.Linear(28*28,28*28) for _ in range(layer_num)]
 5         for i,layer in enumerate(self.layers):
 6             self.add_module(\'layer_{}\'.format(i),layer)
 7         self.linear_relu_stack = nn.Sequential(
 8             nn.Linear(28*28, 512),
 9             nn.ReLU()
10         )
11 
12     def forward(self, x):
13         for layer in layers:
14             x = layer(x)
15         logits = self.linear_relu_stack(x)
16         return logits

这样我们再输出模型的子模块的时候,就会得到:

model = NeuralNetwork(4)
for index,item in enumerate(model.children()):
    print(index,item)

# output
#0 Linear(in_features=784, out_features=784, bias=True)
#1 Linear(in_features=784, out_features=784, bias=True)
#2 Linear(in_features=784, out_features=784, bias=True)
#3 Linear(in_features=784, out_features=784, bias=True)
#4 Sequential(
#  (0): Linear(in_features=784, out_features=512, bias=True)
#  (1): ReLU()
#)

就会看到,已经有了自己注册的模块。

 

当然,也可能觉得这种方式比较麻烦,每次都要自己注册下,那能不能有一个类似List的类,在定义的时候就封装一下呢? 

可以,使用nn.ModuleList封装一下即可达到相同的效果。

class NeuralNetwork(nn.Module):
    def __init__(self, layer_num):
        super(NeuralNetwork, self).__init__()
        self.layers = nn.ModuleList([nn.Linear(28*28,28*28) for _ in range(layer_num)])
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU()
        )

    def forward(self, x):
        for layer in layers:
            x = layer(x)
        logits = self.linear_relu_stack(x)
        return logits

 

参考:
1. 博客THE PYTORCH ADD_MODULE() FUNCTION link
2. pytorch 官方文档 中文链接 English version

以上是关于pytorch中的add_module函数的主要内容,如果未能解决你的问题,请参考以下文章

pytorch中使用add_module添加网络子模块

pytorch中的顺序容器——torch.nn.Sequential

pytorch-卷积基本网络结构-提取网络参数-初始化网络参数

PyTorch 中的 tensordot 以及 einsum 函数介绍

说话人识别损失函数的PyTorch实现与代码解读

调用模板化成员函数:帮助我理解另一个 *** 帖子中的代码片段