对于Pytorch中dim=1的理解

Posted Cai Xukun

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了对于Pytorch中dim=1的理解相关的知识,希望对你有一定的参考价值。

目  录

1 理解shape

2 理解dim

3 理解模型预测中的dim


1 理解shape

对于python中shape的理解:

(1,2) 表示1个一维数组,每个一维数组长度为2;

(1,2,3) 表示1个二维数组,每个二维数组有2个一维数组,每个一维数组长度为3;

(1,2,3,4) 表示1个三维数组,每个三维数组有2个二维数组,每个二维数组有3个一维数组,每个一维数组长度为4。


以下面的tensor为例:

import torch
a = torch.tensor([[[[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]]]])
print(a.shape)

输出结果为torch.Size([1, 1, 3, 3]),表示1个三维数组,每个三维数组中有1个二维数组,每个二维数组中有3个一维数组,一维数组分别为[1, 2, 3]、[4, 5, 6]、[7, 8, 9],每个一维数组的长度为3。如果我们对tensor进行索引

print(a[0][0][0])
print(a[0][0][1])
print(a[0][0][2])

结果分别为:

tensor([1, 2, 3])
tensor([4, 5, 6])
tensor([7, 8, 9])

要注意,只有一个三维数组,所以第一个索引值只能为0,否则就会报错超出索引值;只有一个二维数组,同理第二个索引值也只能为0;有三个一维数组,第三个索引值可以是0、1、2;每个一维数组长度为3,第四个索引值也可以是0、1、2。


2 理解dim

然后我们再通过torch.mean()函数来理解dim:

a = a.float()  # 先转换成float格式,否则会报错
print(torch.mean(a, dim=2))
print(torch.mean(a, dim=3))

运行结果如下:

tensor([[[4., 5., 6.]]])
tensor([[[2., 5., 8.]]])

通过对比可以看出,对于该数组,dim可取的值为0、1、2、3。dim=2意味着在二维数组上进行求平均值的操作,即对一个矩阵按列求平均值;dim=3意味着在一维数组内进行求平均值的操作,即对每个一维数组求平均值。


3 理解模型预测中的dim

再理解在模型预测中遇到的dim,模型输出的数组outputs为(1, 10),即一个一维数组,一维数组中有10个元素,对于该函数:

torch.max(outputs, dim=1)

dim=1即在outputs中的一维数组内取最大值。

以上是关于对于Pytorch中dim=1的理解的主要内容,如果未能解决你的问题,请参考以下文章

pytorch中gather函数的理解。

torch.gather函数的理解

pytorch中的torch.cat()矩阵拼接的用法及理解

nn.Softmax(dim) 的理解

nn.Softmax(dim) 的理解

Pytorch 中的 dim