PyTorch 中的 register_parameter 和 register_buffer 有啥区别?

Posted

技术标签:

【中文标题】PyTorch 中的 register_parameter 和 register_buffer 有啥区别?【英文标题】:What is the difference between register_parameter and register_buffer in PyTorch?PyTorch 中的 register_parameter 和 register_buffer 有什么区别? 【发布时间】:2019-12-23 17:30:07 【问题描述】:

模块的parameters在训练过程中发生变化,也就是说,它们是在神经网络训练过程中学到的,但是buffer是什么?

它是在神经网络训练期间学习的吗?

【问题讨论】:

【参考方案1】:

Pytorch doc for register_buffer() 方法读取

这通常用于注册不应被视为模型参数的缓冲区。例如,BatchNorm 的 running_mean 不是参数,而是持久状态的一部分。

正如您已经观察到的,模型参数是在训练过程中使用 SGD 学习和更新的。 但是,有时还有其他量是模型“状态”的一部分,应该是 - 保存为state_dict 的一部分。 - 将模型的其余参数移至 cuda()cpu()。 - 使用模型的其余参数转换为 float/half/double。 将这些“参数”注册为模型的 buffer 允许 pytorch 跟踪它们并像常规参数一样保存它们,但会阻止 pytorch 使用 SGD 机制更新它们。

可以在_BatchNorm 模块中找到缓冲区的示例,其中running_meanrunning_varnum_batches_tracked 被注册为缓冲区,并通过累积通过层转发的数据的统计信息来更新。这与使用常规 SGD 优化学习数据仿射变换的 weightbias 参数形成对比。

【讨论】:

【参考方案2】:

您为模块创建的参数和缓冲区 (nn.Module)。

假设你有一个线性层nn.Linear。您已经有 weightbias 参数。但是如果你需要一个新参数,你可以使用register_parameter() 来注册一个新的命名参数,它是一个张量。

当你注册一个新参数时,它会出现在 module.parameters() 迭代器中,但是当你注册一个缓冲区时它不会出现。

区别:

Buffers 被命名为张量,它们不会像参数一样在每一步都更新梯度。 对于缓冲区,您可以创建自定义逻辑(完全由您决定)。

好消息是当您保存模型时,所有参数和缓冲区都会被保存,当您将模型移入或移出 CUDA 参数和缓冲区时也会随之消失。

【讨论】:

以上是关于PyTorch 中的 register_parameter 和 register_buffer 有啥区别?的主要内容,如果未能解决你的问题,请参考以下文章

Pytorch Note37 PyTorch 中的循环神经网络模块

Pytorch中的Conv1d()函数

pytorch中的数据加载(dataset基类,以及pytorch自带数据集)

pytorch torch类

Pytorch 中的 dim

PyTorch从入门到精通100讲-神经网络在pytorch中的应用