pytorch中使用add_module添加网络子模块
Posted 非晚非晚
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了pytorch中使用add_module添加网络子模块相关的知识,希望对你有一定的参考价值。
之前有一篇文章介绍了使用Sequential、ModuleList和ModuleDict添加网络,除此之外,我们还可以通过add_module()
添加每一层,并且为每一层增加了一个单独的名字。add_module()可以快速地替换特定结构可以不用修改过多的代码。
add_module
的功能为Module添加一个子module,对应名字为name。使用方式如下:
add_module(name, module)
其中name为子模块的名字,使用这个名字可以访问特定的子module。module为我们自定义的子module。
一般情况下子module都是在A.init(self)中定义的,比如A中一个卷积子模块self.conv1 = torch.nn.Conv2d(…)
。此时,这个卷积模块在A的名字其实是’conv1’。
对比之下,add_module()函数就可以在A.init(self)以外定义A的子模块
。如定义同样的卷积子模块,可以通过A.add_module(‘conv1’, torch.nn.Conv2d(…))
。
- 代码举例一:
import torch
class Net3(torch.nn.Module):
def __init__(self):
super(Net3, self).__init__()
self.conv=torch.nn.Sequential()
self.conv.add_module("conv1",torch.nn.Conv2d(3, 32, 3, 1, 1))
self.conv.add_module("relu1",torch.nn.ReLU())
self.conv.add_module("pool1",torch.nn.MaxPool2d(2))
self.dense = torch.nn.Sequential()
self.dense.add_module("dense1",torch.nn.Linear(32 * 3 * 3, 128))
self.dense.add_module("relu2",torch.nn.ReLU())
self.dense.add_module("dense2",torch.nn.Linear(128, 10))
def forward(self, x):
conv_out = self.conv1(x)
res = conv_out.view(conv_out.size(0), -1)
out = self.dense(res)
return out
model3 = Net3()
print(model3)
输出:
Net3(
(conv): Sequential(
(conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(relu1): ReLU()
(pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(dense): Sequential(
(dense1): Linear(in_features=288, out_features=128, bias=True)
(relu2): ReLU()
(dense2): Linear(in_features=128, out_features=10, bias=True)
)
)
- 代码举例二:
from torch import nn
class Net_test(nn.Module):
def __init__(self):
super(Net_test,self).__init__()
self.conv_1 = nn.Conv2d(3,6,3)
self.add_module('conv_2', nn.Conv2d(6,12,3))
self.conv_3 = nn.Conv2d(12,24,3)
def forward(self,x):
x = self.conv_1(x)
x = self.conv_2(x)
x = self.conv_3(x)
return x
model = Net_test()
print(model)
输出:
Net_test(
(conv_1): Conv2d(3, 6, kernel_size=(3, 3), stride=(1, 1))
(conv_2): Conv2d(6, 12, kernel_size=(3, 3), stride=(1, 1))
(conv_3): Conv2d(12, 24, kernel_size=(3, 3), stride=(1, 1))
)
以上是关于pytorch中使用add_module添加网络子模块的主要内容,如果未能解决你的问题,请参考以下文章
pytorch中的顺序容器——torch.nn.Sequential
pytorch-卷积基本网络结构-提取网络参数-初始化网络参数