pytorch常用张量操作

Posted sunupo

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了pytorch常用张量操作相关的知识,希望对你有一定的参考价值。

1.torch.sum

z =  torch.arange(40.).reshape(2, 4, 5)
print(z)
print(torch.sum( z,0))
print(torch.sum( z,1))
print(torch.sum( z,2))
tensor([[[ 0.,  1.,  2.,  3.,  4.],
         [ 5.,  6.,  7.,  8.,  9.],
         [10., 11., 12., 13., 14.],
         [15., 16., 17., 18., 19.]],

        [[20., 21., 22., 23., 24.],
         [25., 26., 27., 28., 29.],
         [30., 31., 32., 33., 34.],
         [35., 36., 37., 38., 39.]]])
tensor([[20., 22., 24., 26., 28.],
        [30., 32., 34., 36., 38.],
        [40., 42., 44., 46., 48.],
        [50., 52., 54., 56., 58.]])
tensor([[ 30.,  34.,  38.,  42.,  46.],
        [110., 114., 118., 122., 126.]])
tensor([[ 10.,  35.,  60.,  85.],
        [110., 135., 160., 185.]])

2.torch.tensordot

    a = torch.arange(60.).reshape(3, 4, 5)
    b = torch.arange(24.).reshape(4, 3, 2)
    print(a,b)
    print(torch.tensordot(a, b, dims=([1, 0], [0, 1])))
tensor([[[ 0.,  1.,  2.,  3.,  4.],
         [ 5.,  6.,  7.,  8.,  9.],
         [10., 11., 12., 13., 14.],
         [15., 16., 17., 18., 19.]],

        [[20., 21., 22., 23., 24.],
         [25., 26., 27., 28., 29.],
         [30., 31., 32., 33., 34.],
         [35., 36., 37., 38., 39.]],

        [[40., 41., 42., 43., 44.],
         [45., 46., 47., 48., 49.],
         [50., 51., 52., 53., 54.],
         [55., 56., 57., 58., 59.]]]) 
tensor([[[ 0.,  1.],
         [ 2.,  3.],
         [ 4.,  5.]],

        [[ 6.,  7.],
         [ 8.,  9.],
         [10., 11.]],

        [[12., 13.],
         [14., 15.],
         [16., 17.]],

        [[18., 19.],
         [20., 21.],
         [22., 23.]]])
tensor([[4400., 4730.],
        [4532., 4874.],
        [4664., 5018.],
        [4796., 5162.],
        [4928., 5306.]])

 

    x_l = torch.arange(15.).reshape(5, 3, 1)
    kernels = torch.arange(3.).reshape( 3, 1)
    xl_w = torch.tensordot(x_l, kernels, dims=([1], [0]))
    d = torch.matmul(x_l, xl_w)
    print(x_l)
    print(kernels)
    print(xl_w)
    print(d)
tensor([[[ 0.],
         [ 1.],
         [ 2.]],

        [[ 3.],
         [ 4.],
         [ 5.]],

        [[ 6.],
         [ 7.],
         [ 8.]],

        [[ 9.],
         [10.],
         [11.]],

        [[12.],
         [13.],
         [14.]]])
tensor([[0.],
        [1.],
        [2.]])
tensor([[[ 5.]],

        [[14.]],

        [[23.]],

        [[32.]],

        [[41.]]])
tensor([[[  0.],
         [  5.],
         [ 10.]],

        [[ 42.],
         [ 56.],
         [ 70.]],

        [[138.],
         [161.],
         [184.]],

        [[288.],
         [320.],
         [352.]],

        [[492.],
         [533.],
         [574.]]])

 

以上是关于pytorch常用张量操作的主要内容,如果未能解决你的问题,请参考以下文章

pytorch常用张量操作

AI常用框架和工具丨12. 深度学习框架PyTorch

AI常用框架和工具丨12. 深度学习框架PyTorch

AI常用框架和工具丨12. 深度学习框架PyTorch

PyTorch常用代码段合集

pytorch 中的常用矩阵操作