极智AI | pytorch改变tensor维度的方法

Posted 极智视界

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了极智AI | pytorch改变tensor维度的方法相关的知识,希望对你有一定的参考价值。

欢迎关注我的公众号 [极智视界],获取我的更多笔记分享

大家好,我是极智视界,本文介绍一下 pytorch改变tensor维度的方法

在 pytorch 中,tensor 是基本的操作数据结构。在很多的时候,咱们需要改变 tensor 的维度来适应咱们的计算,包括升维、降维、变维。在 pytorch 中有很多方法可以用来改变 tensor 的维度。

这里我把几种常用的方法进行了一下汇总:

  • view(shape):返回一个新的 tensor,它具有给定的形状。如果元素总数不变,则可以用它来改变 tensor 的维度。例如:
import torch

t = torch.tensor([
    [1, 2, 3],
    [4, 5, 6]
])
print(t.shape)  # torch.Size([2, 3])

t_view = t.view(3, 2)
print(t_view.shape)  # torch.Size([3, 2])
  • unsqueeze(dim):返回一个新的 tensor,它的指定位置插入了一个新的维度。例如:
import torch

t = torch.tensor([
    [1, 2, 3],
    [4, 5, 6]
])
print(t.shape)  # torch.Size([2, 3])

t_unsqueeze = t.unsqueeze(0)
print(t_unsqueeze.shape)  # torch.Size([1, 2, 3])

t_unsqueeze = t.unsqueeze(1)
print(t_unsqueeze.shape)  # torch.Size([2, 1, 3])

t_unsqueeze = t.unsqueeze(2)
print(t_unsqueeze.shape)  # torch.Size([2, 3, 1])
  • squeeze(dim):返回一个新的 tensor,它的指定位置的维度的大小为 1 的维度被删除。例如:
import torch

t = torch.tensor([
    [[1], [2], [3]],
    [[4], [5], [6]]
])
print(t.shape)  # torch.Size([2, 3, 1])

t_squeeze = t.squeeze(2)
print(t_squeeze.shape)  # torch.Size([2, 3])

t_squeeze = t.squeeze()
print(t_squeeze.shape)  # torch.Size([2, 3])
  • transpose(dim0, dim1):返回一个新的 tensor,它的排列被交换。例如:
import torch

t = torch.tensor([
    [1, 2, 3],
    [4, 5, 6]
])
print(t.shape)  # torch.Size([2, 3])

t_transpose = t.transpose(0, 1)
print(t_transpose.shape)  # torch.Size([3, 2])

t_transpose = t.transpose(1, 0)
print(t_transpose.shape)  # torch.Size([3, 2])

另外还有一些其他的方法可以改变tensor的维度,例如 permute() 和 contiguous()。

好了,以上分享了 pytorch中改变tensor维度的方法,希望我的分享能对你的学习有一点帮助。

 

【极智视界】

极智AI | pytorch改变tensor维度的方法


搜索关注我的微信公众号「极智视界」,获取我的更多经验分享,让我们用极致+极客的心态来迎接AI !

以上是关于极智AI | pytorch改变tensor维度的方法的主要内容,如果未能解决你的问题,请参考以下文章

PyTorch中使用 Noneunsqueeze()squeeze() 改变 tensor 的维度

PyTorch中Tensor的维度变换实现

pytorch的Tensor的操作

PyTorch:tensor-张量维度操作(拼接维度扩展压缩转置重复……)

我的NVIDIA开发者之旅 - 极智AI | TensorRT 中 Layer 和 Tensor 的区别

Pytorch中的tensor常用操作