PyTorch 张量广播
Posted
技术标签:
【中文标题】PyTorch 张量广播【英文标题】:PyTorch Tensor broadcasting 【发布时间】:2021-12-31 10:11:14 【问题描述】:我正在尝试弄清楚如何进行以下广播:
我有两个张量,大小分别为 (n1,N) 和 (n2,N)
我想要做的是将第一个张量的每一行与第二个张量的每一行相乘,然后将每个相乘的行结果相加,这样我的最终张量应该是 (n1,n2 )。
我试过了:
x1*torch.reshape(x2,(x2.size(dim=0),x2.size(dim=1),1))
但显然这不起作用..无法弄清楚如何做到这一点
【问题讨论】:
【参考方案1】:您要查找的是来自PyTorch 和Numpy 的Tensordot
命令
由于您想计算沿 N 的点积,即 x1
的维度 1 和 x2
张量的维度 1,您需要通过提供 ([1], [1])
沿两个张量的第一轴执行收缩到 Tensordot 中的 dims
arg。这意味着 Torch 将分别在指定的 x1 轴 1
和指定的 x2 轴 1
上求和 x1
和 x2
元素的乘积。提供给dims
的参数很混乱,这里有一个有用的线程来帮助理解如何使用Tensordot
here
x1 = torch.arange(6.).reshape(2,3)
>>> tensor([[0., 1., 2.],
[3., 4., 5.]])
# x1 is Tensor of shape (2,3)
x2 = torch.arange(9.).reshape(3,3)
>>> tensor([[0., 1., 2.],
[3., 4., 5.],
[6., 7., 8.]])
# x2 is Tensor of shape (3,3)
x = torch.tensordot(x1, x2, dims=([1],[1]))
>>> tensor([[ 5., 14., 23.],
[14., 50., 86.]])
# x is Tensor of shape (2,3)
【讨论】:
【参考方案2】:您所描述的似乎实际上与在第一个张量和第二个张量的转置之间执行矩阵乘法相同。这可以这样做:
torch.matmul(x1, x2.T)
【讨论】:
以上是关于PyTorch 张量广播的主要内容,如果未能解决你的问题,请参考以下文章
[PyTroch系列-11]:PyTorch基础 - 张量Tensor元素的排序
[PyTroch系列-7]:PyTorch基础 - 张量Tensor的算术运算