如何将张量逐行乘以 PyTorch 中的向量?
Posted
技术标签:
【中文标题】如何将张量逐行乘以 PyTorch 中的向量?【英文标题】:How to multiply a tensor row-wise by a vector in PyTorch? 【发布时间】:2019-05-28 00:18:48 【问题描述】:当我有一个形状为[12, 10]
的张量m
和一个形状为[12]
的标量向量s
时,如何将m
的每一行与s
中的相应标量相乘?
【问题讨论】:
【参考方案1】:您可以将向量广播到更高维张量like so:
def row_mult(t,vector):
extra_dims = (1,)*(t.dim()-1)
return t * vector.view(-1, *extra_dims)
【讨论】:
【参考方案2】:如果您事先知道维数并且可以硬编码正确数量的None
,Shai 的答案就有效。这可以扩展到需要额外的维度:
mask = (torch.rand(12) > 0.5).int()
data = (torch.rand(12, 2, 3, 4))
result = data * mask[:,None,None,None]
result.shape # torch.Size([12, 2, 3, 4])
mask[:,None,None,None].shape # torch.Size([12, 1, 1, 1])
如果您正在处理可变或未知维度的数据,则可能需要手动将mask
扩展为正确的形状
mask = (torch.rand(12) > 0.5).int()
while mask.dim() < data.dim(): mask.unsqueeze_(1)
result = data * mask
result.shape # torch.Size([12, 2, 3, 4])
mask.shape # torch.Size([12, 1, 1, 1])
这是一个有点丑陋的解决方案,但它确实有效。可能有一种更优雅的方法来正确地重塑 mask
内联张量以适应可变数量的维度
【讨论】:
【参考方案3】:需要添加对应的单例维度:
m * s[:, None]
s[:, None]
在将(12, 10)
张量乘以(12, 1)
张量时的大小为(12, 1)
,pytoch 知道broadcast s
沿第二个单例维度并正确执行“元素级”乘积。
【讨论】:
以上是关于如何将张量逐行乘以 PyTorch 中的向量?的主要内容,如果未能解决你的问题,请参考以下文章