Pytroch同一个优化器优化多个模型的参数并且保存优化后的参数
Posted 像风一样自由的小周
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Pytroch同一个优化器优化多个模型的参数并且保存优化后的参数相关的知识,希望对你有一定的参考价值。
在进行深度学习过程中会遇到几个模型进行串联,这几个模型需要使用同一个优化器,但每个模型的学习率或者动量等其他参数不一样这种情况。一种解决方法是新建一个模型将这几个模型进行串联,另一种解决方法便是往优化器里面传入这几个模型的参数。
一、参考链接
二、Pytroch同一个优化器载入多个模型的参数
为了方便介绍,这里面取两个模型进行说明,多个模型可以以此类推。
假设:1.采用的优化器为Adam,2.两个模型名称分别为Encoder(lr=0.01)、Decoder(lr=0.0001)。则代码如下:
optim.SGD([
'params': Encoder.parameters(),
'params': Decoder.parameters(), 'lr': 1e-4
], lr=1e-2, momentum=0.9)
这意味着Encoder
的参数将会使用1e-2
的学习率,Decoder
的参数将会使用1e-4
的学习率,并且0.9
的momentum
将会被用于所 有的参数。
个人觉得这个地方不熟悉是因为对pytroch的optimizer如何实现不熟悉,下面为pytroch的optimizer的初始化代码:
class Optimizer(object):
r"""Base class for all optimizers.
.. warning::
Parameters need to be specified as collections that have a deterministic
ordering that is consistent between runs. Examples of objects that don't
satisfy those properties are sets and iterators over values of dictionaries.
需要将参数指定为具有在运行之间一致的确定性排序的集合。不满足这些属性的对象示例是字典值的集合和迭代器。
Arguments:
params (iterable): an iterable of :class:`torch.Tensor` s or
:class:`dict` s. Specifies what Tensors should be optimized.
defaults: (dict): a dict containing default values of optimization
options (used when a parameter group doesn't specify them).
params (iterable):torch.Tensor 或 dict 的可迭代对象。指定应该优化哪些张量。
默认值:(dict):包含优化选项默认值的字典(在参数组未指定它们时使用)。
"""
def __init__(self, params, defaults):
torch._C._log_api_usage_once("python.optimizer")
self.defaults = defaults
if isinstance(params, torch.Tensor):
raise TypeError("params argument given to the optimizer should be "
"an iterable of Tensors or dicts, but got " +
torch.typename(params))
self.state = defaultdict(dict)
self.param_groups = []
param_groups = list(params)
if len(param_groups) == 0:
raise ValueError("optimizer got an empty parameter list")
if not isinstance(param_groups[0], dict):
param_groups = ['params': param_groups]
for param_group in param_groups:
self.add_param_group(param_group)
三、保存并加载多个模型的参数
这里取两个模型进行说明(Encoder和Decoder)
保存
state = 'Encoder':Encoder.state_dict(),
'Decoder':Decoder.state_dict()
torch.save(state, filename) # state为一个字典,filename为保存的文件名称.pth
加载
checkpoint = torch.load(filename)
Encoder.load_state_dict(checkpoint['Encoder']) # 这里checkpoint可以看为字典,和之前保存的state相对应
Decoder.load_state_dict(checkpoint['Decoder'])
以上是关于Pytroch同一个优化器优化多个模型的参数并且保存优化后的参数的主要内容,如果未能解决你的问题,请参考以下文章