如何以原生 Pytorch 格式定义模型并导入 LightningModule,无需复制和粘贴?
Posted
技术标签:
【中文标题】如何以原生 Pytorch 格式定义模型并导入 LightningModule,无需复制和粘贴?【英文标题】:How to Define Model in Native Pytorch Format and Import Into LightningModule Without Copy and Pasting? 【发布时间】:2021-11-29 18:53:16 【问题描述】:假设我有一个像这样的原生 pytorch 模型
class NormalAutoEncoder(nn.Module)):
def __init__(self):
super().__init__()
self.encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
self.decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))
def forward(self, x):
# in lightning, forward defines the prediction/inference actions
embedding = self.encoder(x)
return embedding
如何在不复制粘贴的情况下将__init__
和forward
功能(基本上是全网)放到pytorch光照模块中?
【问题讨论】:
【参考方案1】:简单。利用 Python 的继承机制。
如果以下是原生 PyTorch 模块
class NormalAutoEncoder(nn.Module):
def __init__(self):
super().__init__()
self.encoder = ...
self.decoder = ...
def forward(self, x):
embedding = ...
return embedding
然后让你的新LightningAutoEncoder
也继承自NormalAutoEncoder
class LightningAutoEncoder(LightningModule, NormalAutoEncoder):
def __init__(self, ...):
LightningModule.__init__(self) # only LightningModule's init
NormalAutoEncoder.__init__(self, ...) # this basically executes __init__() of the NormalAutoEncoder
def forward(self, x):
# offloads its execution to NormalAutoEncoder's forward() method
return NormalAutoEncoder.forward(self, x)
就是这样。禁止复制粘贴。
【讨论】:
很好的解决方案!但是,如果不将 NormalAutoEncoder 类作为函数参数传递,有没有办法做到这一点? 它不是函数参数,它是继承。你的要求到底是什么?为什么不想继承?在这种情况下,您只需创建一个NormalAutoEncoder
实例作为 Lightning 类的属性。以上是关于如何以原生 Pytorch 格式定义模型并导入 LightningModule,无需复制和粘贴?的主要内容,如果未能解决你的问题,请参考以下文章
PyTorch 和Albumentations 在图像分割的应用
[Pytorch系列-43]:工具集 - torchvision预训练模型参数的导入(以ResNet为例)