pytorch模型参数
Posted lucifer1997
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了pytorch模型参数相关的知识,希望对你有一定的参考价值。
1、torch.nn.state_dict():
返回一个字典,保存着module的所有状态(state)。
parameters和persistent_buffers都会包含在字典中,字典的key就是parameter和buffer的names。
例子:
import torch from torch.autograd import Variable import torch.nn as nn class Model(nn.Module): def __init__(self): super(Model, self).__init__() self.conv2 = nn.Linear(1, 2) self.vari = Variable(torch.rand([1])) self.par = nn.Parameter(torch.rand([1])) self.register_buffer("buffer", torch.randn([2,3])) model = Model() print(model.state_dict().keys())
odict_keys([‘par‘, ‘buffer‘, ‘conv2.weight‘, ‘conv2.bias‘])
字典迭代形式<class ‘str‘>:<class ‘torch.Tensor‘>, ...
以上是关于pytorch模型参数的主要内容,如果未能解决你的问题,请参考以下文章