PyTorch网络搭建中*list的用法解析
Posted 算法与编程之美
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了PyTorch网络搭建中*list的用法解析相关的知识,希望对你有一定的参考价值。
问题
stage1 = nn.Sequential(
nn.Sequential(
nn.Conv2d(16, 32, 3, 1, 1),
nn.ReLU(),
nn.Conv2d(32, 64, 3, 1, 1),
nn.ReLU(),
)
)
stage3 = nn.Sequential(
nn.Conv2d(3, 16, 3, 1, 1),
list(stage1.children())[0], #! 但是可以把list中的元素加进来
)
stage4 = nn.Sequential(
nn.Conv2d(3, 16, 3, 1, 1),
*list(stage1.children())[0], #! 解包序列后再将每个层加入进来
)
stage3和stage4都可以添加到nn.Sequential()中,二者的区别是什么?
方法
import torch
from torch import nn
stage1 = nn.Sequential(
nn.Sequential(
nn.Conv2d(16, 32, 3, 1, 1),
nn.ReLU(),
nn.Conv2d(32, 64, 3, 1, 1),
nn.ReLU(),
)
)
# stage2 = nn.Sequential(
# nn.Conv2d(3, 16, 3, 1, 1),
# list(stage1.children()), #! 不能把一个list加进来,因为list is not a Module subclass
# )
# print(stage2)
stage3 = nn.Sequential(
nn.Conv2d(3, 16, 3, 1, 1),
list(stage1.children())[0], #! 但是可以把list中的元素加进来
)
print(stage3)
'''stage3输出:
Sequential(
(0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): Sequential(
(0): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU()
(2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(3): ReLU()
)
)
'''
'''stage3与stage4的主要区别是:
- stage3将整个Sequential加入进来,
- 而stage4首先将Sequential解包而后再加入进来;
'''
stage4 = nn.Sequential(
nn.Conv2d(3, 16, 3, 1, 1),
*list(stage1.children())[0], #! 解包序列后再将每个层加入进来
)
print(stage4)
''' stage4输出:
Sequential(
(0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(2): ReLU()
(3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): ReLU()
)
'''
结语
以上是关于PyTorch网络搭建中*list的用法解析的主要内容,如果未能解决你的问题,请参考以下文章
睿智的目标检测65——Pytorch搭建DETR目标检测平台
睿智的目标检测53——Pytorch搭建YoloX目标检测平台
睿智的目标检测60——Pytorch搭建YoloV7目标检测平台
睿智的目标检测56——Pytorch搭建YoloV5目标检测平台