pytorch中使用add_module添加网络子模块

Posted 非晚非晚

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了pytorch中使用add_module添加网络子模块相关的知识,希望对你有一定的参考价值。

之前有一篇文章介绍了使用Sequential、ModuleList和ModuleDict添加网络,除此之外,我们还可以通过add_module()添加每一层,并且为每一层增加了一个单独的名字。add_module()可以快速地替换特定结构可以不用修改过多的代码。

add_module的功能为Module添加一个子module,对应名字为name。使用方式如下:

add_module(name, module)

其中name为子模块的名字,使用这个名字可以访问特定的子module。module为我们自定义的子module。

一般情况下子module都是在A.init(self)中定义的,比如A中一个卷积子模块self.conv1 = torch.nn.Conv2d(…)。此时,这个卷积模块在A的名字其实是’conv1’。

对比之下,add_module()函数就可以在A.init(self)以外定义A的子模块。如定义同样的卷积子模块,可以通过A.add_module(‘conv1’, torch.nn.Conv2d(…))

  • 代码举例一:
import torch

class Net3(torch.nn.Module):
  def __init__(self):
    super(Net3, self).__init__()
    self.conv=torch.nn.Sequential()
    self.conv.add_module("conv1",torch.nn.Conv2d(3, 32, 3, 1, 1))
    self.conv.add_module("relu1",torch.nn.ReLU())
    self.conv.add_module("pool1",torch.nn.MaxPool2d(2))
    self.dense = torch.nn.Sequential()
    self.dense.add_module("dense1",torch.nn.Linear(32 * 3 * 3, 128))
    self.dense.add_module("relu2",torch.nn.ReLU())
    self.dense.add_module("dense2",torch.nn.Linear(128, 10))
 
  def forward(self, x):
    conv_out = self.conv1(x)
    res = conv_out.view(conv_out.size(0), -1)
    out = self.dense(res)
    return out
 
model3 = Net3()
print(model3)

输出:

Net3(
  (conv): Sequential(
    (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (relu1): ReLU()
    (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (dense): Sequential(
    (dense1): Linear(in_features=288, out_features=128, bias=True)
    (relu2): ReLU()
    (dense2): Linear(in_features=128, out_features=10, bias=True)
  )
)
  • 代码举例二:
from torch import nn

class Net_test(nn.Module):
    def __init__(self):
        super(Net_test,self).__init__()
        self.conv_1 = nn.Conv2d(3,6,3)
        self.add_module('conv_2', nn.Conv2d(6,12,3))
        self.conv_3 = nn.Conv2d(12,24,3)
        
    def forward(self,x):
        x = self.conv_1(x)
        x = self.conv_2(x)
        x = self.conv_3(x)
        return x
    
model = Net_test()
print(model)

输出:

Net_test(
  (conv_1): Conv2d(3, 6, kernel_size=(3, 3), stride=(1, 1))
  (conv_2): Conv2d(6, 12, kernel_size=(3, 3), stride=(1, 1))
  (conv_3): Conv2d(12, 24, kernel_size=(3, 3), stride=(1, 1))
)

以上是关于pytorch中使用add_module添加网络子模块的主要内容,如果未能解决你的问题,请参考以下文章

pytorch中的add_module函数

pytorch中的顺序容器——torch.nn.Sequential

pytorch-卷积基本网络结构-提取网络参数-初始化网络参数

pytorch中的神经网络子模块(线性模块)——torch.nn.Linear

PyTorch:在训练中添加验证错误

基于PyTorch,如何构建一个简单的神经网络