PyTorch:state_dict 和 parameters() 有啥区别?
Posted
技术标签:
【中文标题】PyTorch:state_dict 和 parameters() 有啥区别?【英文标题】:PyTorch: What's the difference between state_dict and parameters()?PyTorch:state_dict 和 parameters() 有什么区别? 【发布时间】:2019-07-11 19:27:16 【问题描述】:为了在pytorch中访问模型的参数,我看到了两种方法:
使用state_dict
和使用parameters()
我想知道有什么区别,或者一个是好的做法,另一个是坏的做法。
谢谢
【问题讨论】:
【参考方案1】:parameters()
仅给出模块参数,即权重和偏差。
返回模块参数的迭代器。
您可以查看参数列表如下:
for name, param in model.named_parameters():
if param.requires_grad:
print(name)
另一方面,state_dict
返回一个包含整个模块状态的字典。检查其source code
,其中不仅包含对parameters
的调用,还包含buffers
等。
包括参数和持久缓冲区(例如运行平均值)。键是对应的参数和缓冲区名称。
检查state_dict
包含的所有键:
model.state_dict().keys()
例如,在state_dict
中,您会发现bn1.running_mean
和running_var
之类的条目,它们在.parameters()
中不存在。
如果你只想访问参数,你可以简单地使用.parameters()
,而在迁移学习中保存和加载模型等目的,你需要保存state_dict
而不仅仅是参数。
【讨论】:
是否还有第三种访问参数的方法,它返回所有参数的张量,用于向量运算? 我认为.parameters()
是返回所有参数的最简单方法(可能是唯一的方法)。其他方法(如果存在)可能只涉及对其的调用。
如果您能看看我的另一个问题***.com/questions/54734556/… 并提出正确的更新规则实现方法,我将非常感激【参考方案2】:
除了@kHarshit 的答案不同之外,net.parameters()
中可训练张量的属性requires_grad
是True
,而net.state_dict()
中的False
是False
【讨论】:
有人能解释一下为什么会有这种差异吗?此外,我实际上是在使用 torch.equal 对来自 model.named_parameters 和 model.state_dict 的张量进行映射,甚至在运行回传之前(所以没有更新)并且两者是不同的。关于为什么会这样的任何想法。以上是关于PyTorch:state_dict 和 parameters() 有啥区别?的主要内容,如果未能解决你的问题,请参考以下文章
Pytorch:保存模型或 state_dict 给出不同的磁盘空间占用
pytorch中model.parameters()和model.state_dict()使用时的区别
pytorch中model.parameters()和model.state_dict()使用时的区别
pytorch中model.parameters()和model.state_dict()使用时的区别