PyTorch:state_dict 和 parameters() 有啥区别?

Posted

技术标签:

【中文标题】PyTorch:state_dict 和 parameters() 有啥区别?【英文标题】:PyTorch: What's the difference between state_dict and parameters()?PyTorch:state_dict 和 parameters() 有什么区别? 【发布时间】:2019-07-11 19:27:16 【问题描述】:

为了在pytorch中访问模型的参数,我看到了两种方法:

使用state_dict 和使用parameters()

我想知道有什么区别,或者一个是好的做法,另一个是坏的做法。

谢谢

【问题讨论】:

【参考方案1】:

parameters() 仅给出模块参数,即权重和偏差。

返回模块参数的迭代器。

您可以查看参数列表如下:

for name, param in model.named_parameters():
    if param.requires_grad:
        print(name)

另一方面,state_dict 返回一个包含整个模块状态的字典。检查其source code,其中不仅包含对parameters 的调用,还包含buffers 等。

包括参数和持久缓冲区(例如运行平均值)。键是对应的参数和缓冲区名称。

检查state_dict 包含的所有键:

model.state_dict().keys()

例如,在state_dict 中,您会发现bn1.running_meanrunning_var 之类的条目,它们在.parameters() 中不存在。


如果你只想访问参数,你可以简单地使用.parameters(),而在迁移学习中保存和加载模型等目的,你需要保存state_dict而不仅仅是参数。

【讨论】:

是否还有第三种访问参数的方法,它返回所有参数的张量,用于向量运算? 我认为.parameters() 是返回所有参数的最简单方法(可能是唯一的方法)。其他方法(如果存在)可能只涉及对其的调用。 如果您能看看我的另一个问题***.com/questions/54734556/… 并提出正确的更新规则实现方法,我将非常感激【参考方案2】:

除了@kHarshit 的答案不同之外,net.parameters() 中可训练张量的属性requires_gradTrue,而net.state_dict() 中的FalseFalse

【讨论】:

有人能解释一下为什么会有这种差异吗?此外,我实际上是在使用 torch.equal 对来自 model.named_pa​​rameters 和 model.state_dict 的张量进行映射,甚至在运行回传之前(所以没有更新)并且两者是不同的。关于为什么会这样的任何想法。

以上是关于PyTorch:state_dict 和 parameters() 有啥区别?的主要内容,如果未能解决你的问题,请参考以下文章

Pytorch:保存模型或 state_dict 给出不同的磁盘空间占用

pytorch中model.parameters()和model.state_dict()使用时的区别

pytorch中model.parameters()和model.state_dict()使用时的区别

pytorch中model.parameters()和model.state_dict()使用时的区别

PyTorch Big Graph 嵌入数据集中优化器 state_dict 的目的是啥?

colab pytorch保存模型