PyTorch中nn.Module类中__call__方法介绍

Posted fengbingchun

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了PyTorch中nn.Module类中__call__方法介绍相关的知识,希望对你有一定的参考价值。

      在PyTorch源码的torch/nn/modules/module.py文件中,有一条__call__语句和一条forward语句,如下:

__call__ : Callable[…, Any] = _call_impl
forward: Callable[…, Any] = _forward_unimplemented

      在PyTorch中nn.Module类是所有神经网络模块的基类,你的网络也应该继承这个类,需要重载__init__和forward函数。以下是仿照PyTorch中Module和AlexNet类实现写的假的实现的测试代码:

from typing import Callable, Any, List

def _forward_unimplemented(self, *input: Any) -> None:
    "Should be overridden by all subclasses"
    print("_forward_unimplemented")
    raise NotImplementedError

class Module:
    def __init__(self):
        print("Module.__init__")

    forward: Callable[..., Any] = _forward_unimplemented

    def _call_impl(self, *input, **kwargs):
        print("Module._call_impl")
        result = self.forward(*input, **kwargs)
        return result

    __call__: Callable[..., Any] = _call_impl

    def cpu(self):
        print("Module.cpu")

class AlexNet(Module):
    def __init__(self):
        print("AlexNet.__init__")
        super(AlexNet, self).__init__()

    def forward(self, x):
        print("AlexNet.forward")
        return x

model = AlexNet()
x: List[int] = [1, 2, 3, 4]
print("result:", model(x))

model.cpu()

print("test finish")

      执行model(x)语句时,会调用AlexNet的forward函数,是因为AlexNet的父类Module中的__call__函数:首先Module中有__call__方法,因此model(x)这条语句可以正常执行。Module中并没有直接给出__call__的实现体,而是__call__后紧跟冒号,此冒号表示类型注解;后面的Callable和Any是typing模块中的,Callable表示可调用类型,即等号右边应该是一个可调用类型,此处指的是_call_impl;Any是一种特殊的类型,它与所有类型兼容;Callable[…, Any]表示_call_impl可接受任意数量的参数并返回Any。这里__call__实际指向了_call_impl函数,因此调用__call__实际是调用_call_impl。

      typing模块的介绍参考:https://blog.csdn.net/fengbingchun/article/details/122288737

      _call_impl函数体内会调用forward,Module中的forward的实现方式与__call__相同,但是_forward_unimplemented函数并没有实现体,调用它会触发Error即NotImplementedError。因此在子类AlexNet中一定要给出forward的具体实现,否则调用的将是_forward_unimplemented。

      测试代码执行结果如下:

      如果注释掉AlexNet中的forward,则执行结果如下:

      GitHub: https://github.com/fengbingchun/PyTorch_Test

以上是关于PyTorch中nn.Module类中__call__方法介绍的主要内容,如果未能解决你的问题,请参考以下文章

pytorch 参数初始化

Django 是不是可以 调用pytorch 知乎

『PyTorch』第七弹_nn.Module扩展层

『PyTorch』第十四弹_torch.nn.Module深入分析

每天讲解一点PyTorch 14模型定义,继承nn.Module

每天讲解一点PyTorch 14模型定义,继承nn.Module