什么时候需要 pytorch 自定义功能(而不仅仅是一个模块)?
Posted
技术标签:
【中文标题】什么时候需要 pytorch 自定义功能(而不仅仅是一个模块)?【英文标题】:when is a pytorch custom function needed (rather than only a module)? 【发布时间】:2017-11-09 18:13:26 【问题描述】:Pytorch 初学者在这里!考虑以下自定义模块:
class Testme(nn.Module):
def __init__(self):
super(Testme, self).__init__()
def forward(self, x):
return x / t_.max(x).expand_as(x)
据我了解文档:
我相信这也可以作为自定义Function
来实现。 Function
的子类需要 backward()
方法,但是
Module
没有。同样,在线性Module
的文档示例中,它依赖于线性Function
:
class Linear(nn.Module):
def __init__(self, input_features, output_features, bias=True):
...
def forward(self, input):
return Linear()(input, self.weight, self.bias)
问题:我不明白Module
和Function
之间的关系。在上面的第一个清单(模块Testme
)中,它应该有关联的功能吗?如果没有,那么可以通过子类化 Module 来实现这个没有 backward
方法,那么为什么 Function
总是需要 backward
方法呢?
也许Function
s 仅适用于不是由现有 Torch 函数组成的函数?换一种说法:如果模块的forward
方法完全由先前定义的torch 函数组成,也许模块不需要关联的Function
?
【问题讨论】:
【参考方案1】:这些信息是从官方 PyTorch 文档中收集和总结的。
torch.autograd.Function
really 是 PyTorch 中 autograd 包的核心。您在 PyTorch 中构建的任何图形以及您在 PyTorch 中对 Variables
执行的任何操作都基于 Function
。任何函数都需要 __init__(), forward()
和 backward()
方法(在此处查看更多信息:http://pytorch.org/docs/notes/extending.html)。这使 PyTorch 能够计算结果并计算 Variables
的梯度。
nn.Module()
相比之下实际上只是为了方便组织你的模型、你的不同层等。例如,它将你模型中所有可训练的参数组织在.parameters()
中,并允许你在模型中添加另一个层这不是定义后向方法的地方,因为在forward()
方法中,您应该使用Function()
的子类,您已经为此定义backward()
。因此,如果您在 forward()
中指定了操作顺序,PyTorch 已经知道如何反向传播梯度。
现在,你应该什么时候使用?
如果您的操作只是 PyTorch 中现有已实现函数的组合(如您上面的内容),那么您自己向 Function() 添加任何子类确实没有意义。因为您可以将操作堆叠起来并构建动态图。然而,将这些操作捆绑在一起是一个明智的想法。如果任何操作涉及可训练参数(例如神经网络的线性层),您应该将nn.Module()
子类化并在 forward 方法中将您的操作捆绑在一起。这使您可以轻松访问参数(如上所述)以使用torch.optim
等。如果您没有任何可训练的参数,我可能仍然会将它们捆绑在一起,但是一个标准的 Python 函数,您可以在其中处理您使用的每个操作的实例化就足够了。
如果你有一个新的自定义操作(例如一个带有一些复杂采样过程的新随机层),你应该继承 Function()
并定义 __init__(), forward()
和 backward()
来告诉 PyTorch 如何计算结果以及如何计算梯度, 当你使用这个操作时。之后,如果您的操作具有可训练的参数,您应该创建一个函数版本来处理实例化函数并使用您的操作或创建一个模块。同样,您可以在上面的链接中阅读更多相关信息。
【讨论】:
很好的答案,谢谢。我不明白一件事,这句话:只是将它们捆绑在一个函数中,您在其中参与每个操作的实例化。 1.我假设“函数”是指python函数,而不是torch函数? 2.这个我就是看不懂意思:“你在哪里参与每个操作的实例化。” 1.对,就是这样。将它们捆绑在一起,我的意思是您将其用作包装器的 python 函数。这是因为torch.autograd.Function
的所有子类都是类。因此,您需要先创建它们的实例,然后才能使用它们。假设你实现了 class My_OP(torch.autograd.Function): .....,你会像这样使用它:My_OP()(input)。特别是当操作需要初始化参数时,将其放在 Pyhton 函数中会更方便,即 def my_op(input): return My_OP()(input)。这在上面的链接中也得到了很好的解释。 2.意思是小心不要参与以上是关于什么时候需要 pytorch 自定义功能(而不仅仅是一个模块)?的主要内容,如果未能解决你的问题,请参考以下文章