PyTorch: Running a Network on an Arbitrary Set of GPUs
Posted by William.csj
This post shows how to run a DataParallel-wrapped network on an arbitrary subset of GPUs, avoiding errors like "RuntimeError: module must have its parameters and buffers on device cuda:0 (device_ids[0])", which DataParallel raises when the wrapped module is not placed on the first device in device_ids.
Before:

# Original: `.to(device)` can place the wrapper on a GPU that is not device_ids[0]
# self.model_with_loss = DataParallel(
#     self.model_with_loss, device_ids=gpus,
#     chunk_sizes=chunk_sizes).to(device)

After:

# Option 1: run on any set of GPUs
self.model_with_loss = DataParallel(
    self.model_with_loss, device_ids=gpus,
    chunk_sizes=chunk_sizes)
device = torch.device(gpus[0])  # make the first GPU in the list the master device
self.model_with_loss.to(device)

# Option 2: run on any set of GPUs (devices hard-coded)
# self.model_with_loss = DataParallel(
#     self.model_with_loss, device_ids=[2, 3],
#     chunk_sizes=chunk_sizes)
# device = torch.device("cuda:2")  # make GPU 2 the master device
# self.model_with_loss.to(device)
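The key fact behind both options is that DataParallel gathers its replicas' outputs on device_ids[0], so the wrapped module (and its inputs) must live on that GPU, not necessarily on cuda:0. Below is a minimal, self-contained sketch of the same idea using the stock torch.nn.DataParallel (the snippet above relies on a custom DataParallel that also takes chunk_sizes); the GPU indices 2 and 3 are only an example and assume the machine exposes at least four GPUs.

import torch
from torch import nn

# Minimal sketch with the stock torch.nn.DataParallel (no chunk_sizes).
# Assumes GPUs 2 and 3 exist on this machine.
gpus = [2, 3]

model = nn.Linear(16, 4)
model = nn.DataParallel(model, device_ids=gpus)

# The wrapped module must live on device_ids[0]. Moving it to cuda:0
# instead triggers "RuntimeError: module must have its parameters and
# buffers on device cuda:2 (device_ids[0]) ...".
device = torch.device(gpus[0])
model.to(device)

x = torch.randn(8, 16, device=device)  # input placed on the master device
y = model(x)                           # outputs are gathered back on device_ids[0]
print(y.device)                        # cuda:2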
Full code:
import torch

# NOTE: DataParallel here is a custom class that also accepts chunk_sizes
# (per-GPU batch split sizes); the stock torch.nn.DataParallel has no such
# argument, so this code assumes that custom class is imported.


class ModleWithLoss(torch.nn.Module):
    # Wraps the model together with its loss so both run inside DataParallel.
    def __init__(self, model, loss):
        super(ModleWithLoss, self).__init__()
        self.model = model
        self.loss = loss

    def forward(self, batch):
        outputs = self.model(batch['input'])
        loss, loss_stats = self.loss(outputs, batch)
        return outputs[-1], loss, loss_stats


class BaseTrainer(object):
    def __init__(self, opt, model, optimizer=None):
        self.opt = opt
        self.optimizer = optimizer
        self.loss_stats, self.loss = self._get_losses(opt)  # _get_losses is defined elsewhere (not shown here)
        self.model_with_loss = ModleWithLoss(model, self.loss)
        # Let the optimizer also update any learnable parameters of the loss.
        self.optimizer.add_param_group({'params': self.loss.parameters()})

    def set_device(self, gpus, chunk_sizes, device):
        if len(gpus) > 1:  # multi-GPU
            # Original:
            # self.model_with_loss = DataParallel(
            #     self.model_with_loss, device_ids=gpus,
            #     chunk_sizes=chunk_sizes).to(device)
            # New:
            # ------------------------------------------
            # Option 1: run on any set of GPUs
            self.model_with_loss = DataParallel(
                self.model_with_loss, device_ids=gpus,
                chunk_sizes=chunk_sizes)
            device = torch.device(gpus[0])  # the first GPU in the list becomes the master device
            self.model_with_loss.to(device)
            # Option 2: run on any set of GPUs (devices hard-coded)
            # self.model_with_loss = DataParallel(
            #     self.model_with_loss, device_ids=[2, 3],
            #     chunk_sizes=chunk_sizes)
            # device = torch.device("cuda:2")  # GPU 2 becomes the master device
            # self.model_with_loss.to(device)
            # ------------------------------------------
        else:  # single GPU
            self.model_with_loss = self.model_with_loss.to(device)
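To show the whole flow end to end, here is a hedged, self-contained toy sketch that mirrors the pattern above using the stock torch.nn.DataParallel (no chunk_sizes): a model-plus-loss wrapper is parallelized over an arbitrary GPU list, the batch dict is moved to the master device gpus[0], and one forward/backward step is run. ToyModelWithLoss, the layer sizes, and the GPU indices are illustrative only and are not part of the original trainer.

import torch
from torch import nn

class ToyModelWithLoss(nn.Module):
    # Mirrors the ModleWithLoss pattern above with toy components.
    def __init__(self):
        super(ToyModelWithLoss, self).__init__()
        self.model = nn.Linear(16, 4)
        self.loss = nn.MSELoss()

    def forward(self, batch):
        output = self.model(batch['input'])
        loss = self.loss(output, batch['target'])
        return output, loss

gpus = [2, 3]                                   # any subset of visible GPUs
model_with_loss = ToyModelWithLoss()
model_with_loss = nn.DataParallel(model_with_loss, device_ids=gpus)
device = torch.device(gpus[0])                  # first listed GPU is the master
model_with_loss.to(device)

optimizer = torch.optim.SGD(model_with_loss.parameters(), lr=0.1)

batch = {
    'input': torch.randn(8, 16),
    'target': torch.randn(8, 4),
}
# Move the batch to the master device, as a training loop would.
batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}

output, loss = model_with_loss(batch)
optimizer.zero_grad()
loss.mean().backward()                          # one loss value comes back per GPU chunk
optimizer.step()

Because each replica returns its own scalar loss, the gathered loss is a small vector with one entry per GPU, hence the loss.mean() before calling backward.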
References
- Fixing "RuntimeError: module must have its parameters and buffers on device cuda:0 (device_ids[0])"
- [PyTorch Study Notes] 7.3 Training a Model on the GPU