pytorch模型参数

Posted lucifer1997

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了pytorch模型参数相关的知识,希望对你有一定的参考价值。

1、torch.nn.state_dict():

返回一个字典,保存着module的所有状态(state)。

parameters和persistent_buffers都会包含在字典中,字典的key就是parameter和buffer的names。

例子:

import torch
from torch.autograd import Variable
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv2 = nn.Linear(1, 2)
        self.vari = Variable(torch.rand([1]))
        self.par = nn.Parameter(torch.rand([1]))
        self.register_buffer("buffer", torch.randn([2,3]))

model = Model()
print(model.state_dict().keys())
odict_keys([par, buffer, conv2.weight, conv2.bias])

 

字典迭代形式<class ‘str‘>:<class ‘torch.Tensor‘>, ...

以上是关于pytorch模型参数的主要内容,如果未能解决你的问题,请参考以下文章

Pytorch 之 模型的保存与调用

Pytorch冻结部分层的参数

[Pytorch]Pytorch 保存模型与加载模型(转)

PyTorch参数模型转换为PT模型

pytorch模型文件pth详解

pytorch自动删除之前保存的pt文件