pytorch 笔记:torchsummary

Posted UQI-LIUWJ

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了pytorch 笔记:torchsummary相关的知识,希望对你有一定的参考价值。

作用:打印神经网络的结构

pytorch笔记:搭建简易CNN_UQI-LIUWJ的博客-CSDN博客 中搭建的CNN为例

import torch
from torchsummary import summary

class CNN(nn.Module):
    def __init__(self):
        super(CNN,self).__init__()
 
        self.conv1=nn.Sequential(
            nn.Conv2d(
                in_channels=1,
#输入shape (1,28,28)
                out_channels=16,
#输出shape(16,28,28),16也是卷积核的数量
                kernel_size=5,
                stride=1,
                padding=2),
#如果想要conv2d出来的图片长宽没有变化,那么当stride=1的时候,padding=(kernel_size-1)/2
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
 #在2*2空间里面下采样,输出shape(16,14,14)
        )
           
        self.conv2=nn.Sequential(
            nn.Conv2d(
                in_channels=16,
#输入shape (16,14,14)
                out_channels=32,
#输出shape(32,14,14)
                kernel_size=5,
                stride=1,
                padding=2),
#输出shape(32,7,7),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        )
 
        self.fc=nn.Linear(32*7*7,10)
#输出一个十维的东西,表示我每个数字可能性的权重
        
    def forward(self,x):
            x=self.conv1(x)
            x=self.conv2(x)
            x=x.view(x.shape[0],-1)
            x=self.fc(x)
            return x
    
cnn=CNN()
summary(cnn,(1,28,28))

输出的结果是这样的:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 16, 28, 28]             416
              ReLU-2           [-1, 16, 28, 28]               0
         MaxPool2d-3           [-1, 16, 14, 14]               0
            Conv2d-4           [-1, 32, 14, 14]          12,832
              ReLU-5           [-1, 32, 14, 14]               0
         MaxPool2d-6             [-1, 32, 7, 7]               0
            Linear-7                   [-1, 10]          15,690
================================================================
Total params: 28,938
Trainable params: 28,938
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.32
Params size (MB): 0.11
Estimated Total Size (MB): 0.44
----------------------------------------------------------------

以上是关于pytorch 笔记:torchsummary的主要内容,如果未能解决你的问题,请参考以下文章

PyTorch 打印模型结构输出维度和参数信息(torchsummary)

pytorch笔记:VGG 16

PyTorch打印模型结构输出维度和参数信息(torchsummary)

Pytorch Note57 Pytorch可视化网络结构

错误处理:RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be

pytorch用于对比试验,caffe用于工程落地