Pytorch 设备和 .to(device) 方法
Posted
技术标签:
【中文标题】Pytorch 设备和 .to(device) 方法【英文标题】:Pytorch device and .to(device) method 【发布时间】:2020-06-28 00:38:11 【问题描述】:我正在尝试学习 RNN 和 Pytorch。
所以我看到了一些 RNN 的代码,在前向检验方法中,他们做了这样的检查:
def forward(self, inputs, hidden):
if inputs.is_cuda:
device = inputs.get_device()
else:
device = torch.device("cpu")
embed_out = self.embeddings(inputs)
logits = torch.zeros(self.seq_len, self.batch_size, self.vocab_size).to(device)
我认为检查的重点是看看我们是否可以在更快的 GPU 而不是 CPU 上运行代码? 为了进一步理解代码,我做了以下操作:
ex= torch.zeros(3,10,5)
ex1= torch.tensor(np.array([[0,0,0,1,0], [1,0,0,0,0],[0,1,0,0,0]]))
print(ex)
print("device is")
print(ex1.get_device())
print(ex.to(ex1.get_device()))
输出是:
...
[[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]]])
device is
-1
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-2-b09342e2ba0f> in <module>()
67 print("device is")
68 print(ex1.get_device())
---> 69 print(ex.to(ex1.get_device()))
RuntimeError: Device index must not be negative
我看不懂代码中的“设备”,也看不懂.to(device)
方法。你能帮我理解一下吗?
【问题讨论】:
这能回答你的问题吗? Using CUDA with pytorch? 你有 GPU 吗?你能告诉我命令torch.cuda.is_available()
的输出吗?
【参考方案1】:
此代码已弃用。做吧:
def forward(self, inputs, hidden):
embed_out = self.embeddings(inputs)
logits = torch.zeros((self.seq_len, self.batch_size, self.vocab_size), device=inputs.device)
请注意,如果张量已经在请求的设备上,to(device)
是免费的。并且不要使用get_device()
,而是使用device
属性。开箱即用的 cpu 和 gpu 运行良好。
另外,请注意torch.tensor(np.array(...))
是一种不好的做法,原因有很多。首先,要将 numpy 数组转换为 Torch 张量,请使用as_tensor
或from_numpy
。然后,您将获得一个具有默认 numpy dtype 而不是火炬的张量。在这种情况下,它是相同的(int64),但对于 float,它会有所不同。最后torch.tensor
可以用list初始化,就像numpy数组一样,完全摆脱numpy,直接调用torch。
【讨论】:
以上是关于Pytorch 设备和 .to(device) 方法的主要内容,如果未能解决你的问题,请参考以下文章
PyTorch:tensor.cuda() 和 tensor.to(torch.device("cuda:0")) 有啥区别?
Pytorch Lighting 的 model.to(device)