pytorch笔记 torchviz

Posted UQI-LIUWJ

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了pytorch笔记 torchviz相关的知识,希望对你有一定的参考价值。

画出pytorch的流程图

1 举例(Sequential中内容未命名)

from torchviz import make_dot
import torch

model = torch.nn.Sequential(
  torch.nn.Linear(8, 16),
  torch.nn.Tanh(),
  torch.nn.Linear(16, 1))

x = torch.randn(1, 8)
y = model(x)


viz_graph=make_dot(y.mean(), params=dict(model.named_parameters()))
viz_graph.view()
# 会在当前目录下保存一个“Digraph.gv.pdf”文件,并在默认浏览器中打开

 此时是因为我们没有给sequential中的module命名,所以0代表第一个module,2代表tanh之后的那个module

2 举例(Sequential+OrderedDict)

from torchviz import make_dot
from collections import OrderedDict
import torch

model = torch.nn.Sequential(OrderedDict([
  ('Linear1',torch.nn.Linear(8, 16)),
  ('Tanh',torch.nn.Tanh()),
  ('Linear2',torch.nn.Linear(16, 1))]))

x = torch.randn(1, 8)
y = model(x)

#make_dot(y.mean(), params=dict(model.named_parameters()))
viz_graph=make_dot(y.mean(), params=dict(model.named_parameters()))
viz_graph.view()
# 会在当前目录下保存一个“Digraph.gv.pdf”文件,并在默认浏览器中打开

 3 举例(Sequential+add Module)

from torchviz import make_dot
from collections import OrderedDict
import torch

model = torch.nn.Sequential()

model.add_module('W0', torch.nn.Linear(8, 16))
model.add_module('tanh', torch.nn.Tanh())
model.add_module('W1', torch.nn.Linear(16, 1))

x = torch.randn(1, 8)
y = model(x)

#make_dot(y.mean(), params=dict(model.named_parameters()))
viz_graph=make_dot(y.mean(), params=dict(model.named_parameters()))
viz_graph.view()
# 会在当前目录下保存一个“Digraph.gv.pdf”文件,并在默认浏览器中打开

4 举例(稍微复杂一点的CNN)

 

from torchviz import make_dot
import torch.nn as nn
import torch


class CNN(nn.Module):
    def __init__(self):
        super(CNN,self).__init__()
 
        self.conv1=nn.Sequential(
            nn.Conv2d(
                in_channels=1,#输入shape (1,28,28)
                out_channels=16,#输出shape(16,28,28),16也是卷积核的数量
                kernel_size=5,
                stride=1,
                padding=2),
#如果想要conv2d出来的图片长宽没有变化,那么当stride=1的时候,padding=(kernel_size-1)/2
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)#在2*2空间里面下采样,输出shape(16,14,14)
        )
           
        self.conv2=nn.Sequential(
            nn.Conv2d(
                in_channels=16,#输入shape (16,14,14)
                out_channels=32,#输出shape(32,14,14)
                kernel_size=5,
                stride=1,
                padding=2),#输出shape(32,7,7),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        )
 
        self.fc=nn.Linear(32*7*7,10)#输出一个十维的东西,表示我每个数字可能性的权重
        
    def forward(self,x):
            x=self.conv1(x)
            x=self.conv2(x)
            x=x.view(x.shape[0],-1)
            x=self.fc(x)
            return x
model=CNN()

x = torch.randn(16,1, 28,28)
y = model(x)

#make_dot(y.mean(), params=dict(model.named_parameters()))
viz_graph=make_dot(y.mean(), params=dict(model.named_parameters()))
viz_graph.view()
# 会在当前目录下保存一个“Digraph.gv.pdf”文件,并在默认浏览器中打开

 

以上是关于pytorch笔记 torchviz的主要内容,如果未能解决你的问题,请参考以下文章

PyTorch学习笔记:模型定义修改保存

Pytorch实战笔记

PyTorch学习笔记:PyTorch可视化

pytorch学习笔记

Pytorch学习笔记:基本概念安装张量操作逻辑回归

pytorch笔记01-数据增强