PyTorch学习Tensor与Numpy数组的相互转换

Posted Daylight..

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了PyTorch学习Tensor与Numpy数组的相互转换相关的知识,希望对你有一定的参考价值。

torch的Tensor转化为NumPy数组

Tensor和NumPy数组共享相同的底层储存位置,因此当一个改变时,另一个也会改变。
创建tensor
首先导入torch和numpy包

import torch
import numpy as np

创建一个tensor

a = torch.zeros(5)
print(a)

输出结果:

tensor([0., 0., 0., 0., 0.])

tensor转换为numpy

b = a.numpy()
print(b)

输出结果:

[0. 0. 0. 0. 0.]

当所转换的tensor发生变化时,numpy数组也会改变。

a.add_(1)
print(a)
print(b)

输出结果:

tensor([1., 1., 1., 1., 1.])
[1. 1. 1. 1. 1.]

NumPy数组转化为Tensor

创建一个numpy数组

a = np.zeros(5)
print(a)

输出结果:

[0. 0. 0. 0. 0.]

numpy数组转换为tensor

b = torch.from_numpy(a)
print(b)

输出结果:

tensor([0., 0., 0., 0., 0.], dtype=torch.float64)

当被转换numpy数组发展变化,转换得到的tensor也会发生变化。

np.add(a, 2,out=a)
print(a)
print(b)

输出结果:

[2. 2. 2. 2. 2.]
tensor([2., 2., 2., 2., 2.], dtype=torch.float64)

CPU上的所有tensor(CharTensor除外)都支持与Numpy数组的相互转换。

CUDA上的Tensor

张量可以使用.to()方法移动到任何设备(device)上:
前提GPU可用,使用torch.device来将tensor移入和移出GPU
创建一个tensor

x = torch.rand(2,3)
y = torch.rand_like(x) 
print(x)
print(y)

输出结果:

tensor([[0.7374, 0.2935, 0.4500],
        [0.9148, 0.7752, 0.5846]])
tensor([[0.0828, 0.5807, 0.8807],
        [0.9329, 0.8767, 0.0201]])

将tensor放入GPU以便进行加速:

x = x.to('cuda')
print(x)

输出结果:

tensor([[0.7374, 0.2935, 0.4500],
        [0.9148, 0.7752, 0.5846]], device='cuda:0')

我们已经把x放入GPU中了,现在y还在cpu中。如果这个时候进行x和y之间的运算会报错如下:

z = x + y

输出结果:

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
译文:RuntimeError:期望所有张量在同一个设备上,但发现至少两个设备,cuda:0和cpu!

所以我们只能对在同一个设备上的数据进行相互之间的操作处理。
将在GPU中的x放入CPU中,x和y在同一个设备中,就可以进行x和y相加。

x = x.to('cpu')
print(x)
z = x +y
print(z)

输出结果:

tensor([[0.7374, 0.2935, 0.4500],
        [0.9148, 0.7752, 0.5846]])
tensor([[0.8202, 0.8742, 1.3308],
        [1.8477, 1.6520, 0.6047]])

以上是关于PyTorch学习Tensor与Numpy数组的相互转换的主要内容,如果未能解决你的问题,请参考以下文章

[Pytorch]Tensor

Pytorch 入门与实战----pytorch入门

《动手学深度学习》PyTorch: 数据操作

PyTorch 内存模型:“torch.from_numpy()”与“torch.Tensor()”

Tensor:Pytorch神经网络界的Numpy

Tensor:Pytorch神经网络界的Numpy