Tips pytorch reshape()函数的用法
Posted 海绵_青年
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Tips pytorch reshape()函数的用法相关的知识,希望对你有一定的参考价值。
tensor张量定义
import torch
# 多行多列张量
a = torch.Tensor([ [1,2,3,4], [2,3,4,5]])
# 多行一列张量
b = torch.Tensor( [1,2,3,4] )
b_reshaped = torch.Tensor( [1,2,3,4] ).reshape(-1, 1)
# 一行一列张量
c = torch.Tensor([1])
c_reshaped = torch.Tensor([1]).reshape(-1, 1)
# 单元素张量
d = torch.Tensor(1)
结果输出:
a: torch.Size([2, 4])
------------------
b: torch.Size([4])
b_reshaped: torch.Size([4, 1])
------------------
c: torch.Size([1])
c_reshaped: torch.Size([1, 1])
------------------
d: torch.Size([1])
结论
- 是否使用reshape()函数,不影响矩阵相乘
print(torch.matmul(a, b).shape, torch.matmul(a, b_reshaped).shape)
-------------------------
torch.Size([2]) torch.Size([2, 1])
- 是否使用reshape()函数,影响矩阵剪切
print(torch.cat([b_reshaped, b_reshaped], axis=1).shape)
print(torch.cat([b, b], axis=1).shape)
-------------------------
torch.Size([4, 2])
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
Cell In[15], line 3
1 # 是否使用reshape,影响矩阵剪切
2 print(torch.cat([b_reshaped, b_reshaped], axis=1).shape)
----> 3 print(torch.cat([b, b], axis=1).shape)
IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)
以上是关于Tips pytorch reshape()函数的用法的主要内容,如果未能解决你的问题,请参考以下文章
pytorch列优先(fortran-like)reshape的实现与性能