使用 PyTorch 张量将对角线屏蔽为特定值
Posted
技术标签:
【中文标题】使用 PyTorch 张量将对角线屏蔽为特定值【英文标题】:Masking diagonal to a specific value with PyTorch tensors 【发布时间】:2018-09-05 19:53:24 【问题描述】:如何用torch中的值填充对角线?在 numpy 中你可以这样做:
a = np.zeros((3, 3), int)
np.fill_diagonal(a, 5)
array([[5, 0, 0],
[0, 5, 0],
[0, 0, 5]])
我知道torch.diag()
返回对角线,但是如何使用它作为掩码来分配新值是我无法理解的。我无法在此处或 PyTorch 文档中找到答案。
【问题讨论】:
【参考方案1】:一种方法:
>>> import torch
>>> n = 3
>>> t = torch.zeros((n,n))
>>> t[torch.eye(n).byte()] = 5
>>> t
5 0 0
0 5 0
0 0 5
[torch.FloatTensor of size 3x3]
【讨论】:
高版本pytorch最好转用torch.eye(n).bool()。【参考方案2】:您可以在 PyTorch 中使用 fill_diagonal_
执行此操作:
>>> a = torch.zeros(3, 3)
>>> a.fill_diagonal_(5)
tensor([[5, 0, 0],
[0, 5, 0],
[0, 0, 5]])
【讨论】:
以上是关于使用 PyTorch 张量将对角线屏蔽为特定值的主要内容,如果未能解决你的问题,请参考以下文章
Pytorch深度学习实战3-2:什么是张量?Tensor的创建与索引