Pytorch 设备和 .to(device) 方法

Posted

技术标签:

【中文标题】Pytorch 设备和 .to(device) 方法【英文标题】:Pytorch device and .to(device) method 【发布时间】:2020-06-28 00:38:11 【问题描述】:

我正在尝试学习 RNN 和 Pytorch。

所以我看到了一些 RNN 的代码,在前向检验方法中,他们做了这样的检查:

def forward(self, inputs, hidden):
    if inputs.is_cuda:
        device = inputs.get_device()
    else:
        device = torch.device("cpu")
    embed_out = self.embeddings(inputs)
    logits = torch.zeros(self.seq_len, self.batch_size, self.vocab_size).to(device)

我认为检查的重点是看看我们是否可以在更快的 GPU 而不是 CPU 上运行代码? 为了进一步理解代码,我做了以下操作:

ex= torch.zeros(3,10,5)
ex1= torch.tensor(np.array([[0,0,0,1,0], [1,0,0,0,0],[0,1,0,0,0]]))

print(ex)
print("device is")
print(ex1.get_device())
print(ex.to(ex1.get_device()))

输出是:

        ...
        [[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]]])
device is
-1
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-2-b09342e2ba0f> in <module>()
     67 print("device is")
     68 print(ex1.get_device())
---> 69 print(ex.to(ex1.get_device()))

RuntimeError: Device index must not be negative

我看不懂代码中的“设备”,也看不懂.to(device)方法。你能帮我理解一下吗?

【问题讨论】:

这能回答你的问题吗? Using CUDA with pytorch? 你有 GPU 吗?你能告诉我命令torch.cuda.is_available()的输出吗? 【参考方案1】:

此代码已弃用。做吧:

def forward(self, inputs, hidden):
    embed_out = self.embeddings(inputs)
    logits = torch.zeros((self.seq_len, self.batch_size, self.vocab_size), device=inputs.device)

请注意,如果张量已经在请求的设备上,to(device) 是免费的。并且不要使用get_device(),而是使用device 属性。开箱即用的 cpu 和 gpu 运行良好。

另外,请注意torch.tensor(np.array(...)) 是一种不好的做法,原因有很多。首先,要将 numpy 数组转换为 Torch 张量,请使用as_tensorfrom_numpy。然后,您将获得一个具有默认 numpy dtype 而不是火炬的张量。在这种情况下,它是相同的(int64),但对于 float,它会有所不同。最后torch.tensor可以用list初始化,就像numpy数组一样,完全摆脱numpy,直接调用torch。

【讨论】:

以上是关于Pytorch 设备和 .to(device) 方法的主要内容,如果未能解决你的问题,请参考以下文章

PyTorch:tensor.cuda() 和 tensor.to(torch.device("cuda:0")) 有啥区别?

pytorch介绍和环境配置

Pytorch Lighting 的 model.to(device)

PyTorch Lightning 将张量移动到 validation_epoch_end 中的正确设备

pytorch 多GPU 训练

PyTorch 60 分钟入门教程:数据并行处理