pytorch笔记 torchviz
Posted UQI-LIUWJ
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了pytorch笔记 torchviz相关的知识,希望对你有一定的参考价值。
画出pytorch的流程图
1 举例(Sequential中内容未命名)
from torchviz import make_dot
import torch
model = torch.nn.Sequential(
torch.nn.Linear(8, 16),
torch.nn.Tanh(),
torch.nn.Linear(16, 1))
x = torch.randn(1, 8)
y = model(x)
viz_graph=make_dot(y.mean(), params=dict(model.named_parameters()))
viz_graph.view()
# 会在当前目录下保存一个“Digraph.gv.pdf”文件,并在默认浏览器中打开
此时是因为我们没有给sequential中的module命名,所以0代表第一个module,2代表tanh之后的那个module
2 举例(Sequential+OrderedDict)
from torchviz import make_dot
from collections import OrderedDict
import torch
model = torch.nn.Sequential(OrderedDict([
('Linear1',torch.nn.Linear(8, 16)),
('Tanh',torch.nn.Tanh()),
('Linear2',torch.nn.Linear(16, 1))]))
x = torch.randn(1, 8)
y = model(x)
#make_dot(y.mean(), params=dict(model.named_parameters()))
viz_graph=make_dot(y.mean(), params=dict(model.named_parameters()))
viz_graph.view()
# 会在当前目录下保存一个“Digraph.gv.pdf”文件,并在默认浏览器中打开
3 举例(Sequential+add Module)
from torchviz import make_dot
from collections import OrderedDict
import torch
model = torch.nn.Sequential()
model.add_module('W0', torch.nn.Linear(8, 16))
model.add_module('tanh', torch.nn.Tanh())
model.add_module('W1', torch.nn.Linear(16, 1))
x = torch.randn(1, 8)
y = model(x)
#make_dot(y.mean(), params=dict(model.named_parameters()))
viz_graph=make_dot(y.mean(), params=dict(model.named_parameters()))
viz_graph.view()
# 会在当前目录下保存一个“Digraph.gv.pdf”文件,并在默认浏览器中打开
4 举例(稍微复杂一点的CNN)
from torchviz import make_dot
import torch.nn as nn
import torch
class CNN(nn.Module):
def __init__(self):
super(CNN,self).__init__()
self.conv1=nn.Sequential(
nn.Conv2d(
in_channels=1,#输入shape (1,28,28)
out_channels=16,#输出shape(16,28,28),16也是卷积核的数量
kernel_size=5,
stride=1,
padding=2),
#如果想要conv2d出来的图片长宽没有变化,那么当stride=1的时候,padding=(kernel_size-1)/2
nn.ReLU(),
nn.MaxPool2d(kernel_size=2)#在2*2空间里面下采样,输出shape(16,14,14)
)
self.conv2=nn.Sequential(
nn.Conv2d(
in_channels=16,#输入shape (16,14,14)
out_channels=32,#输出shape(32,14,14)
kernel_size=5,
stride=1,
padding=2),#输出shape(32,7,7),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2)
)
self.fc=nn.Linear(32*7*7,10)#输出一个十维的东西,表示我每个数字可能性的权重
def forward(self,x):
x=self.conv1(x)
x=self.conv2(x)
x=x.view(x.shape[0],-1)
x=self.fc(x)
return x
model=CNN()
x = torch.randn(16,1, 28,28)
y = model(x)
#make_dot(y.mean(), params=dict(model.named_parameters()))
viz_graph=make_dot(y.mean(), params=dict(model.named_parameters()))
viz_graph.view()
# 会在当前目录下保存一个“Digraph.gv.pdf”文件,并在默认浏览器中打开
以上是关于pytorch笔记 torchviz的主要内容,如果未能解决你的问题,请参考以下文章