Tips pytorch reshape()函数的用法

Posted 海绵_青年

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Tips pytorch reshape()函数的用法相关的知识,希望对你有一定的参考价值。

tensor张量定义

import torch

# 多行多列张量
a = torch.Tensor([ [1,2,3,4], [2,3,4,5]])

# 多行一列张量
b = torch.Tensor( [1,2,3,4] )
b_reshaped = torch.Tensor( [1,2,3,4] ).reshape(-1, 1)

# 一行一列张量
c = torch.Tensor([1])
c_reshaped = torch.Tensor([1]).reshape(-1, 1)

# 单元素张量
d = torch.Tensor(1)

结果输出:

a: torch.Size([2, 4])
------------------
b: torch.Size([4])
b_reshaped: torch.Size([4, 1])
------------------
c: torch.Size([1])
c_reshaped: torch.Size([1, 1])
------------------
d: torch.Size([1])

结论

  1. 是否使用reshape()函数,不影响矩阵相乘
print(torch.matmul(a, b).shape, torch.matmul(a, b_reshaped).shape)
-------------------------
torch.Size([2]) torch.Size([2, 1])
  1. 是否使用reshape()函数,影响矩阵剪切
print(torch.cat([b_reshaped, b_reshaped], axis=1).shape)
print(torch.cat([b, b], axis=1).shape)
-------------------------
torch.Size([4, 2])
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Cell In[15], line 3
      1 # 是否使用reshape,影响矩阵剪切
      2 print(torch.cat([b_reshaped, b_reshaped], axis=1).shape)
----> 3 print(torch.cat([b, b], axis=1).shape)

IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)

以上是关于Tips pytorch reshape()函数的用法的主要内容,如果未能解决你的问题,请参考以下文章

pytorch列优先(fortran-like)reshape的实现与性能

pytorch中gather函数的理解。

pytorch 常用函数参数详解

pytorch 函数理解

pytorch之transforms.Compose()函数理解

pytorch常用normalization函数