Pytorch DataParallel 与自定义模型

Posted

技术标签:

【中文标题】Pytorch DataParallel 与自定义模型【英文标题】:Pytorch DataParallel with custom model 【发布时间】:2021-12-09 18:57:12 【问题描述】:

我想用多个 gpu 训练模型。我正在使用以下代码

model = load_model(path)
if torch.cuda.device_count() > 1:
  print("Let's use", torch.cuda.device_count(), "GPUs!")
  # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
  model = nn.DataParallel(model)

model.to(device)

它运行良好,只是 DataParallel 不包含原始模型中的函数,有没有办法解决它?谢谢

【问题讨论】:

"DataParallel 不包含原始模型中的函数",你到底是什么意思? @Ivan 我对 ML 很陌生,它是 VQGan 模型,它包含 VectorQuantizer 作为 self.quantize 属性,当我们执行“model = nn.DataParallel(model)”时它丢失了 您好,既然有 pytorch-lightning 的标签,您想查看那里的多 GPU 文档吗? pytorch-lightning.readthedocs.io/en/stable/advanced/… @NanoBit 谢谢,是的模型继承了pl.LightningModule 请澄清您的具体问题或提供其他详细信息以准确突出您的需求。正如目前所写的那样,很难准确地说出你在问什么。 【参考方案1】:

传递给nn.DataParallelnn.Module 最终将被类包装以处理数据并行性。您仍然可以使用 module 属性访问您的模型。

>>> p_model = nn.DataParallel(model)
>>> p_model.module # <- model

例如,要访问底层模型的 quantize 属性,您可以:

>>> p_model.module.quantize

【讨论】:

以上是关于Pytorch DataParallel 与自定义模型的主要内容,如果未能解决你的问题,请参考以下文章

PyTorch-4 nn.DataParallel 数据并行详解

pytorch分布式训练(DataParallel/DistributedDataParallel)

pytorch分布式训练(DataParallel/DistributedDataParallel)

pytorch分布式训练(DataParallel/DistributedDataParallel)

每天讲解一点PyTorch 18多卡训练torch.nn.DataParallel

每天讲解一点PyTorch 18多卡训练torch.nn.DataParallel