如何以原生 Pytorch 格式定义模型并导入 LightningModule,无需复制和粘贴?

Posted

技术标签:

【中文标题】如何以原生 Pytorch 格式定义模型并导入 LightningModule,无需复制和粘贴?【英文标题】:How to Define Model in Native Pytorch Format and Import Into LightningModule Without Copy and Pasting? 【发布时间】:2021-11-29 18:53:16 【问题描述】:

假设我有一个像这样的原生 pytorch 模型

class NormalAutoEncoder(nn.Module)):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
        self.decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))

    def forward(self, x):
        # in lightning, forward defines the prediction/inference actions
        embedding = self.encoder(x)
        return embedding

如何在不复制粘贴的情况下将__init__forward功能(基本上是全网)放到pytorch光照模块中?

【问题讨论】:

【参考方案1】:

简单。利用 Python 的继承机制。

如果以下是原生 PyTorch 模块

class NormalAutoEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = ...
        self.decoder = ...

    def forward(self, x):
        embedding = ...
        return embedding

然后让你的新LightningAutoEncoder 继承自NormalAutoEncoder

class LightningAutoEncoder(LightningModule, NormalAutoEncoder):
    def __init__(self, ...):
        LightningModule.__init__(self) # only LightningModule's init
        NormalAutoEncoder.__init__(self, ...) # this basically executes __init__() of the NormalAutoEncoder

    def forward(self, x):
        # offloads its execution to NormalAutoEncoder's forward() method
        return NormalAutoEncoder.forward(self, x)

就是这样。禁止复制粘贴。

【讨论】:

很好的解决方案!但是,如果不将 NormalAutoEncoder 类作为函数参数传递,有没有办法做到这一点? 它不是函数参数,它是继承。你的要求到底是什么?为什么不想继承?在这种情况下,您只需创建一个 NormalAutoEncoder 实例作为 Lightning 类的属性。

以上是关于如何以原生 Pytorch 格式定义模型并导入 LightningModule,无需复制和粘贴?的主要内容,如果未能解决你的问题,请参考以下文章

PyTorch 和Albumentations 在图像分割的应用

如何将 Pytorch 模型导入 MATLAB

[Pytorch系列-43]:工具集 - torchvision预训练模型参数的导入(以ResNet为例)

PyTorch 和 Albumentations 实现图像分类(猫狗大战)

Pytorch练手项目二——模型微调

深度学习pytorch训练代码模板(个人习惯)