PyTorch 中的 tensordot 以及 einsum 函数介绍
Posted 珍妮的选择
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了PyTorch 中的 tensordot 以及 einsum 函数介绍相关的知识,希望对你有一定的参考价值。
PyTorch 中的 tensordot 以及 einsum 函数介绍
文章目录
前言
最近发现这两个函数用得越来越频繁, 比如在 DCN 网络的实现中就用到了(详见 Deep Cross Network (深度交叉网络, DCN) 介绍与代码分析), 但是过段时间又忘记这两个函数到底实现啥功能, 趁着现在印象还比较深刻的时候记录一下 ????????????.
广而告之
可以在微信中搜索 “珍妮的算法之路” 或者 “world4458” 关注我的微信公众号;另外可以看看知乎专栏 PoorMemory-机器学习, 以后文章也会发在知乎专栏中;
从例子出发
拿一大串中文或英文来形容这两个函数, 该懵逼的还是懵逼, 从例子出发可以很容易理解它们的具体功能. 例子来自 Stackoverflow: product-of-pytorch-tensors-along-arbitrary-axes.
import torch
import numpy as np
a = np.arange(36.).reshape(3,4,3)
b = np.arange(24.).reshape(4,3,2)
c = np.tensordot(a, b, axes=([1,0], [0,1]))
print(c)
# [[ 2640. 2838.] [ 2772. 2982.] [ 2904. 3126.]]
a = torch.from_numpy(a)
b = torch.from_numpy(b)
c = torch.einsum("ijk,jil->kl", (a, b))
print(c)
# tensor([[ 2640., 2838.], [ 2772., 2982.], [ 2904., 3126.]], dtype=torch.float64)
从例子中可以发现, einsum
同样可以实现 tensordot
的功能. 但现在的问题是, c
等于
[[ 2640. 2838.] [ 2772. 2982.] [ 2904. 3126.]]
这个结果具体是怎么得到的 ?
在 numpy.tensordot 文档中, 对 tensordot
的功能解释为:
Compute tensor dot product along specified axes
上面代码中, 在计算 c
时, 指定了 axes
:
c = np.tensordot(a, b, axes=([1,0], [0,1]))
其中 a
用来参与计算的轴为 [1, 0]
, 由于 a.shape = (3, 4, 3)
, 那么用来参与计算的子数组 A
大小为 (4, 3)
;
对于 b
来说, 用来参与计算的轴为 [0, 1]
, 由于 b.shape = (4, 3, 2)
, 那么用来参与计算的子数组 B
大小为 (4, 3)
;
最后进行子数组(tensor)间的 dot product, 即 sum(A * B)
, 得到一个 scalar, 注意 *
是 element-wise 的乘法, 而不是矩阵乘法. 经过 tensordot
后, a
还保留着第 3 个维度, 大小为 a.shape[2] = 3
, 而 b
也保留着第 3 个维度, 大小为 b.shape[2] = 2
, 此时 c
的大小为 (a.shape[2], b.shape[2]) = (3, 2)
.
经过以上分析, 我们现在换种思路来计算 c
, 代码如下:
import numpy as np
a = np.arange(36.).reshape(3,4,3)
b = np.arange(24.).reshape(4,3,2)
aa = a.transpose((2, 1, 0)) ## aa.shape = (3, 4, 3)
bb = b.transpose((2, 0, 1)) ## bb.shape = (2, 4, 3)
print(np.sum(aa[0], bb[0]))
# 2640.0
cc = [
[np.sum(aa[0] * bb[0]), np.sum(aa[0] * bb[1])],
[np.sum(aa[1] * bb[0]), np.sum(aa[1] * bb[1])],
[np.sum(aa[2] * bb[0]), np.sum(aa[2] * bb[1])]
]
print(cc)
[[2640.0, 2838.0], [2772.0, 2982.0], [2904.0, 3126.0]]
因此, tensordot 的作用是将 axes
指定的子数组进行点乘, axes
指定具体的维度.
einsum 的用法非常丰富, 下面参考资料中的例子无不显示着这个函数的强大:
经过上面的分析, 可以发现 enisum
可以完成 tensordot
的功能, 即:
c = torch.einsum("ijk,jil->kl", (a, b))
用指定的字符串 "ijk,jil->kl"
就能形象地说明运算的目的.
灵魂画手
再说明一下 transpose. 它和 reshape
不一样, 我觉得它是改变观看 tensor 的视角. 比如对于如下矩阵:
(需要一点空间想象 ????????????)
参考资料
- Stackoverflow: understanding-pytorch-einsum 例子相当丰富
- einsum满足你一切需要:深度学习中的爱因斯坦求和约定 我只是想学习下中文表达~
- numpy.tensordot Numpy 的 tensordot 文档
- Stackoverflow: understanding-tensordot 解释的很通俗
- Stackoverflow: product-of-pytorch-tensors-along-arbitrary-axes 本文第一个例子的出处
以上是关于PyTorch 中的 tensordot 以及 einsum 函数介绍的主要内容,如果未能解决你的问题,请参考以下文章
pytorch中的数据加载(dataset基类,以及pytorch自带数据集)
markdown PyTorch中的双线性插值,以及基准测试与numpy
PyTorch中的Stack和Cat以及Tensorflow和Numpy的区别