如何检查两个 Torch 张量或矩阵是不是相等?
Posted
技术标签:
【中文标题】如何检查两个 Torch 张量或矩阵是不是相等?【英文标题】:How to check if two Torch tensors or matrices are equal?如何检查两个 Torch 张量或矩阵是否相等? 【发布时间】:2016-01-04 22:10:41 【问题描述】:我需要一个 Torch 命令来检查两个张量是否具有相同的内容,如果它们具有相同的内容则返回 TRUE。
例如:
local tens_a = torch.Tensor(9,8,7,6);
local tens_b = torch.Tensor(9,8,7,6);
if (tens_a EQUIVALENCE_COMMAND tens_b) then ... end
我应该在这个脚本中使用什么来代替EQUIVALENCE_COMMAND
?
我只是尝试了==
,但它不起作用。
【问题讨论】:
要考虑浮点差异,请参阅Check if PyTorch tensors are equal within epsilon。 【参考方案1】:torch.eq(a, b)
eq()
实现==
运算符比较a
中的每个元素与b
(如果b 是一个值)或a
中的每个元素与其对应的b
中的元素(如果b
是张量)。
@deltheil 的替代方案:
torch.all(tens_a.eq(tens_b))
【讨论】:
正如其他答案中提到的,使用当前的火炬,.eq
返回一个张量,而.equal
实际上返回一个布尔值。【参考方案2】:
以下解决方案对我有用:
torch.equal(tensorA, tensorB)
来自the documentation:
True
如果两个张量具有相同的大小和元素,则False
否则。
【讨论】:
这个答案应该是唯一用于这个问题的答案,因为这个函数与 OP 想要的确切行为匹配 + 它是最有效的,如果张量的形状不同,则不进行计算。 【参考方案3】:要比较张量,您可以按元素进行:
torch.eq
是元素方面的:
torch.eq(torch.tensor([[1., 2.], [3., 4.]]), torch.tensor([[1., 1.], [4., 4.]]))
tensor([[True, False], [False, True]])
或者torch.equal
正好代表整个张量:
torch.equal(torch.tensor([[1., 2.], [3, 4.]]), torch.tensor([[1., 1.], [4., 4.]]))
# False
torch.equal(torch.tensor([[1., 2.], [3., 4.]]), torch.tensor([[1., 2.], [3., 4.]]))
# True
但是你可能会迷失方向,因为在某些时候你想忽略一些小的差异。例如 floats 1.0
和 1.0000000001
非常接近,您可能会认为它们是相等的。对于这种比较,你有torch.allclose
。
torch.allclose(torch.tensor([[1., 2.], [3., 4.]]), torch.tensor([[1., 2.000000001], [3., 4.]]))
# True
在某些时候,与元素的全部数量相比,检查元素数量是否相等可能很重要。如果你有两个张量 dt1
和 dt2
你会得到 dt1
的元素数量为 dt1.nelement()
通过这个公式,您可以得到百分比:
print(torch.sum(torch.eq(dt1, dt2)).item()/dt1.nelement())
【讨论】:
torch.allclose() 是我要找的那个。 不等于怎么办?【参考方案4】:如果您想忽略浮点数常见的细微精度差异,请尝试此操作
torch.all(torch.lt(torch.abs(torch.add(tens_a, -tens_b)), 1e-12))
【讨论】:
您也可以使用torch.allclose()
。【参考方案5】:
您可以将两个张量转换为 numpy 数组:
local tens_a = torch.Tensor((9,8,7,6));
local tens_b = torch.Tensor((9,8,7,6));
a=tens_a.numpy()
b=tens_b.numpy()
然后是类似的东西
np.sum(a==b)
4
会让您对他们的平等程度有一个相当好的了解。
【讨论】:
以上是关于如何检查两个 Torch 张量或矩阵是不是相等?的主要内容,如果未能解决你的问题,请参考以下文章