对于Pytorch中dim=1的理解
Posted Cai Xukun
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了对于Pytorch中dim=1的理解相关的知识,希望对你有一定的参考价值。
目 录
1 理解shape
对于python中shape的理解:
(1,2) 表示1个一维数组,每个一维数组长度为2;
(1,2,3) 表示1个二维数组,每个二维数组有2个一维数组,每个一维数组长度为3;
(1,2,3,4) 表示1个三维数组,每个三维数组有2个二维数组,每个二维数组有3个一维数组,每个一维数组长度为4。
以下面的tensor为例:
import torch
a = torch.tensor([[[[1, 2, 3],
[4, 5, 6],
[7, 8, 9]]]])
print(a.shape)
输出结果为torch.Size([1, 1, 3, 3]),表示1个三维数组,每个三维数组中有1个二维数组,每个二维数组中有3个一维数组,一维数组分别为[1, 2, 3]、[4, 5, 6]、[7, 8, 9],每个一维数组的长度为3。如果我们对tensor进行索引
print(a[0][0][0])
print(a[0][0][1])
print(a[0][0][2])
结果分别为:
tensor([1, 2, 3])
tensor([4, 5, 6])
tensor([7, 8, 9])
要注意,只有一个三维数组,所以第一个索引值只能为0,否则就会报错超出索引值;只有一个二维数组,同理第二个索引值也只能为0;有三个一维数组,第三个索引值可以是0、1、2;每个一维数组长度为3,第四个索引值也可以是0、1、2。
2 理解dim
然后我们再通过torch.mean()函数来理解dim:
a = a.float() # 先转换成float格式,否则会报错
print(torch.mean(a, dim=2))
print(torch.mean(a, dim=3))
运行结果如下:
tensor([[[4., 5., 6.]]])
tensor([[[2., 5., 8.]]])
通过对比可以看出,对于该数组,dim可取的值为0、1、2、3。dim=2意味着在二维数组上进行求平均值的操作,即对一个矩阵按列求平均值;dim=3意味着在一维数组内进行求平均值的操作,即对每个一维数组求平均值。
3 理解模型预测中的dim
再理解在模型预测中遇到的dim,模型输出的数组outputs为(1, 10),即一个一维数组,一维数组中有10个元素,对于该函数:
torch.max(outputs, dim=1)
dim=1即在outputs中的一维数组内取最大值。
以上是关于对于Pytorch中dim=1的理解的主要内容,如果未能解决你的问题,请参考以下文章