PyTorch 中的 tensordot 以及 einsum 函数介绍

Posted 珍妮的选择

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了PyTorch 中的 tensordot 以及 einsum 函数介绍相关的知识,希望对你有一定的参考价值。


PyTorch 中的 tensordot 以及 einsum 函数介绍

文章目录

前言

最近发现这两个函数用得越来越频繁, 比如在 DCN 网络的实现中就用到了(详见 ​​Deep Cross Network (深度交叉网络, DCN) 介绍与代码分析​​), 但是过段时间又忘记这两个函数到底实现啥功能, 趁着现在印象还比较深刻的时候记录一下 ????????????.

广而告之

可以在微信中搜索 “珍妮的算法之路” 或者 “world4458” 关注我的微信公众号;另外可以看看知乎专栏 ​​PoorMemory-机器学习​​, 以后文章也会发在知乎专栏中;

从例子出发

拿一大串中文或英文来形容这两个函数, 该懵逼的还是懵逼, 从例子出发可以很容易理解它们的具体功能. 例子来自 ​​Stackoverflow: product-of-pytorch-tensors-along-arbitrary-axes​​.

import torch
import numpy as np

a = np.arange(36.).reshape(3,4,3)
b = np.arange(24.).reshape(4,3,2)
c = np.tensordot(a, b, axes=([1,0], [0,1]))
print(c)
# [[ 2640. 2838.] [ 2772. 2982.] [ 2904. 3126.]]

a = torch.from_numpy(a)
b = torch.from_numpy(b)
c = torch.einsum("ijk,jil->kl", (a, b))
print(c)
# tensor([[ 2640., 2838.], [ 2772., 2982.], [ 2904., 3126.]], dtype=torch.float64)

从例子中可以发现, ​​einsum​​​ 同样可以实现 ​​tensordot​​​ 的功能. 但现在的问题是, ​​c​​ 等于

[[ 2640.  2838.] [ 2772.  2982.] [ 2904.  3126.]]

这个结果具体是怎么得到的 ?

在 ​​numpy.tensordot​​​ 文档中, 对 ​​tensordot​​ 的功能解释为:

Compute tensor dot product along specified axes

上面代码中, 在计算 ​​c​​​ 时, 指定了 ​​axes​​:

c = np.tensordot(a, b, axes=([1,0], [0,1]))

其中 ​​a​​​ 用来参与计算的轴为 ​​[1, 0]​​​, 由于 ​​a.shape = (3, 4, 3)​​​, 那么用来参与计算的子数组 ​​A​​​ 大小为 ​​(4, 3)​​​;
对于 ​​​b​​​ 来说, 用来参与计算的轴为 ​​[0, 1]​​​, 由于 ​​b.shape = (4, 3, 2)​​​, 那么用来参与计算的子数组 ​​B​​​ 大小为 ​​(4, 3)​​​;
最后进行子数组(tensor)间的 dot product, 即 ​​​sum(A * B)​​​, 得到一个 scalar, 注意 ​​*​​​ 是 element-wise 的乘法, 而不是矩阵乘法. 经过 ​​tensordot​​​ 后, ​​a​​​ 还保留着第 3 个维度, 大小为 ​​a.shape[2] = 3​​​, 而 ​​b​​​ 也保留着第 3 个维度, 大小为 ​​b.shape[2] = 2​​​, 此时 ​​c​​​ 的大小为 ​​(a.shape[2], b.shape[2]) = (3, 2)​​.

经过以上分析, 我们现在换种思路来计算 ​​c​​, 代码如下:

import numpy as np
a = np.arange(36.).reshape(3,4,3)
b = np.arange(24.).reshape(4,3,2)
aa = a.transpose((2, 1, 0)) ## aa.shape = (3, 4, 3)
bb = b.transpose((2, 0, 1)) ## bb.shape = (2, 4, 3)
print(np.sum(aa[0], bb[0]))
# 2640.0
cc = [
[np.sum(aa[0] * bb[0]), np.sum(aa[0] * bb[1])],
[np.sum(aa[1] * bb[0]), np.sum(aa[1] * bb[1])],
[np.sum(aa[2] * bb[0]), np.sum(aa[2] * bb[1])]
]
print(cc)
[[2640.0, 2838.0], [2772.0, 2982.0], [2904.0, 3126.0]]

因此, tensordot 的作用是将 ​​axes​​​ 指定的子数组进行点乘, ​​axes​​ 指定具体的维度.

einsum 的用法非常丰富, 下面参考资料中的例子无不显示着这个函数的强大:

经过上面的分析, 可以发现 ​​enisum​​​ 可以完成 ​​tensordot​​ 的功能, 即:

c = torch.einsum("ijk,jil->kl", (a, b))

用指定的字符串 ​​"ijk,jil->kl"​​ 就能形象地说明运算的目的.

灵魂画手

再说明一下 transpose. 它和 ​​reshape​​ 不一样, 我觉得它是改变观看 tensor 的视角. 比如对于如下矩阵:

(需要一点空间想象 ????????????)

PyTorch

参考资料


以上是关于PyTorch 中的 tensordot 以及 einsum 函数介绍的主要内容,如果未能解决你的问题,请参考以下文章

pytorch:What is PyTorch?

pytorch中的数据加载(dataset基类,以及pytorch自带数据集)

markdown PyTorch中的双线性插值,以及基准测试与numpy

PyTorch中的Stack和Cat以及Tensorflow和Numpy的区别

PyTorch中的Stack和Cat以及Tensorflow和Numpy的区别

PyTorch中的Stack和Cat以及Tensorflow和Numpy的区别