排列后如何进行张量点运算

Posted

技术标签:

【中文标题】排列后如何进行张量点运算【英文标题】:How to make tensordot operations after permutation 【发布时间】:2021-04-19 07:52:42 【问题描述】:

我有 2 个张量,A 和 B:

A = torch.randn([32,128,64,12],dtype=torch.float64)
B = torch.randn([64,12,64,12],dtype=torch.float64)
C = torch.tensordot(A,B,([2,3],[0,1]))
D = C.permute(0,2,1,3) # shape:[32,64,128,12]

张量 D 来自操作“tensordot -> permute”。如何实现一个新的操作 f() 以在 f() 之后进行张量点操作,如:

A_2 = f(A)
B_2 = f(B)
D = torch.tensordot(A_2,B_2)

【问题讨论】:

【参考方案1】:

您是否考虑过使用非常灵活的torch.einsum

D = torch.einsum('ijab,abkl->ikjl', A, B)

tensordot 的问题在于它在B 之前输出A 的所有维度,而您正在寻找(在置换时)是从AB 中“交错”维度。

【讨论】:

是的!我确实最后使用了“torch.einsum”。

以上是关于排列后如何进行张量点运算的主要内容,如果未能解决你的问题,请参考以下文章

Torch:如何按行对张量进行洗牌?

如何在张量流中对张量进行子集化?

张量和向量的区别

如何在 TensorFlow 中选择 2D 张量的某些列?

tensorflow2.0张量的结构操作

Pytorch基础-张量基本操作