pytorch中view的使用小结
Posted 非晚非晚
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了pytorch中view的使用小结相关的知识,希望对你有一定的参考价值。
pytorch中的view类似于numpy中的resize和reshape,也就是说它的作用是重新组合tensor数据
。
它的流程为:
- 将原始tensor数据
按行展开
。比如[[[1,2,3],[4,5,6]]],展开成一维向量之后就是[1,2,3,4,5,6]。- 然后按照参数组合成
新的tensor
(与原始tensor没有关联
)。
理解一:
不管你原先的数据是[[[1,2,3],[4,5,6]]]还是[1,2,3,4,5,6],因为它们排成一维向量都是6个元素,所以只要view后面的参数一致,得到的结果都是一样的
。举例如下:
import torch
a=torch.Tensor([[[1,2,3],[4,5,6]]])
b=torch.Tensor([1,2,3,4,5,6])
print(a.view(1,6))
print(b.view(1,6))
输出:
tensor([[1., 2., 3., 4., 5., 6.]])
tensor([[1., 2., 3., 4., 5., 6.]])
理解二:
对原始tensor进行view操作之后,它们共享内存
,其中一个tensor改变,会造成另一个tensor改变。所以view只是改变给人看的数据排列形式。
import torch
x = [1,2,3,4,5,6]
a=torch.Tensor(x)
t = a.view(1, 6)
print(a)
print(t)
a[0] = 10 #修改a[0]的值
print(a)
print(t)#会造成a.view改变,所以t也跟着变化。
t[0][1] = 100 #修改t[0][1]的值
print(a)#会造成a的改变。
print(t)
print(t.storage().data_ptr() == a.storage().data_ptr()) #返回True,也就是共享内存
输出:
tensor([1., 2., 3., 4., 5., 6.])
tensor([[1., 2., 3., 4., 5., 6.]])
tensor([10., 2., 3., 4., 5., 6.])
tensor([[10., 2., 3., 4., 5., 6.]])
tensor([ 10., 100., 3., 4., 5., 6.])
tensor([[ 10., 100., 3., 4., 5., 6.]])
True
view的-1参数
view中的参数不能忽略,如果想要让程序自己推到,可以将参数设置为-1。
- 不能被推断的情况:
情况一
:推断有歧义,比如说有6个数据,而设置view(-1,-1,-2),那么就会造成歧义,因为-1的值可以是1和3。情况二
:view只能出现一个-1
。即使人可以推断出来,也不能view的参数也不可以出现2个以上的-1。例如有6个数据,而view(-1,-1,6),这种情况也是不被允许的。
代码举例:
import torch
a=torch.Tensor([[[1,2,3],[4,5,6]]])
print(a.view(-1,6))
print(a.view(-1,3))
print(a.view(-1,-1,6))#不能出现一个以上的-1推断
输出:
tensor([[1., 2., 3., 4., 5., 6.]])
tensor([[1., 2., 3.],
[4., 5., 6.]])
Traceback (most recent call last):
File "../pytorch/test/test.py", line 7, in <module>
print(a.view(-1,-1,6))
RuntimeError: only one dimension can be inferred
以上是关于pytorch中view的使用小结的主要内容,如果未能解决你的问题,请参考以下文章
SwiftUI 如何修复“无法推断复杂的闭包返回类型;添加显式类型以消除歧义”