PyTorch 中的 register_parameter 和 register_buffer 有啥区别?
Posted
技术标签:
【中文标题】PyTorch 中的 register_parameter 和 register_buffer 有啥区别?【英文标题】:What is the difference between register_parameter and register_buffer in PyTorch?PyTorch 中的 register_parameter 和 register_buffer 有什么区别? 【发布时间】:2019-12-23 17:30:07 【问题描述】:模块的parameters在训练过程中发生变化,也就是说,它们是在神经网络训练过程中学到的,但是buffer是什么?
它是在神经网络训练期间学习的吗?
【问题讨论】:
【参考方案1】:Pytorch doc for register_buffer()
方法读取
这通常用于注册不应被视为模型参数的缓冲区。例如,BatchNorm 的
running_mean
不是参数,而是持久状态的一部分。
正如您已经观察到的,模型参数是在训练过程中使用 SGD 学习和更新的。
但是,有时还有其他量是模型“状态”的一部分,应该是
- 保存为state_dict
的一部分。
- 将模型的其余参数移至 cuda()
或 cpu()
。
- 使用模型的其余参数转换为 float
/half
/double
。
将这些“参数”注册为模型的 buffer
允许 pytorch 跟踪它们并像常规参数一样保存它们,但会阻止 pytorch 使用 SGD 机制更新它们。
可以在_BatchNorm
模块中找到缓冲区的示例,其中running_mean
、running_var
和num_batches_tracked
被注册为缓冲区,并通过累积通过层转发的数据的统计信息来更新。这与使用常规 SGD 优化学习数据仿射变换的 weight
和 bias
参数形成对比。
【讨论】:
【参考方案2】:您为模块创建的参数和缓冲区 (nn.Module
)。
假设你有一个线性层nn.Linear
。您已经有 weight
和 bias
参数。但是如果你需要一个新参数,你可以使用register_parameter()
来注册一个新的命名参数,它是一个张量。
当你注册一个新参数时,它会出现在 module.parameters()
迭代器中,但是当你注册一个缓冲区时它不会出现。
区别:
Buffers 被命名为张量,它们不会像参数一样在每一步都更新梯度。 对于缓冲区,您可以创建自定义逻辑(完全由您决定)。
好消息是当您保存模型时,所有参数和缓冲区都会被保存,当您将模型移入或移出 CUDA 参数和缓冲区时也会随之消失。
【讨论】:
以上是关于PyTorch 中的 register_parameter 和 register_buffer 有啥区别?的主要内容,如果未能解决你的问题,请参考以下文章
Pytorch Note37 PyTorch 中的循环神经网络模块