如何检查两个 Torch 张量或矩阵是不是相等?

Posted

技术标签:

【中文标题】如何检查两个 Torch 张量或矩阵是不是相等?【英文标题】:How to check if two Torch tensors or matrices are equal?如何检查两个 Torch 张量或矩阵是否相等? 【发布时间】:2016-01-04 22:10:41 【问题描述】:

我需要一个 Torch 命令来检查两个张量是否具有相同的内容,如果它们具有相同的内容则返回 TRUE。

例如:

local tens_a = torch.Tensor(9,8,7,6);
local tens_b = torch.Tensor(9,8,7,6);

if (tens_a EQUIVALENCE_COMMAND tens_b) then ... end

我应该在这个脚本中使用什么来代替EQUIVALENCE_COMMAND

我只是尝试了==,但它不起作用。

【问题讨论】:

要考虑浮点差异,请参阅Check if PyTorch tensors are equal within epsilon。 【参考方案1】:
torch.eq(a, b)

eq() 实现== 运算符比较a 中的每个元素与b(如果b 是一个值)或a 中的每个元素与其对应的b 中的元素(如果b是张量)。


@deltheil 的替代方案:

torch.all(tens_a.eq(tens_b))

【讨论】:

正如其他答案中提到的,使用当前的火炬,.eq 返回一个张量,而.equal 实际上返回一个布尔值。【参考方案2】:

以下解决方案对我有用:

torch.equal(tensorA, tensorB)

来自the documentation:

True 如果两个张量具有相同的大小和元素,则False 否则。

【讨论】:

这个答案应该是唯一用于这个问题的答案,因为这个函数与 OP 想要的确切行为匹配 + 它是最有效的,如果张量的形状不同,则不进行计算。 【参考方案3】:

要比较张量,您可以按元素进行:

torch.eq 是元素方面的:

torch.eq(torch.tensor([[1., 2.], [3., 4.]]), torch.tensor([[1., 1.], [4., 4.]]))
tensor([[True, False], [False, True]])

或者torch.equal 正好代表整个张量:

torch.equal(torch.tensor([[1., 2.], [3, 4.]]), torch.tensor([[1., 1.], [4., 4.]]))
# False
torch.equal(torch.tensor([[1., 2.], [3., 4.]]), torch.tensor([[1., 2.], [3., 4.]]))
# True

但是你可能会迷失方向,因为在某些时候你想忽略一些小的差异。例如 floats 1.01.0000000001 非常接近,您可能会认为它们是相等的。对于这种比较,你有torch.allclose

torch.allclose(torch.tensor([[1., 2.], [3., 4.]]), torch.tensor([[1., 2.000000001], [3., 4.]]))
# True

在某些时候,与元素的全部数量相比,检查元素数量是否相等可能很重要。如果你有两个张量 dt1dt2 你会得到 dt1 的元素数量为 dt1.nelement()

通过这个公式,您可以得到百分比:

print(torch.sum(torch.eq(dt1, dt2)).item()/dt1.nelement())

【讨论】:

torch.allclose() 是我要找的那个。 不等于怎么办?【参考方案4】:

如果您想忽略浮点数常见的细微精度差异,请尝试此操作

torch.all(torch.lt(torch.abs(torch.add(tens_a, -tens_b)), 1e-12))

【讨论】:

您也可以使用torch.allclose()【参考方案5】:

您可以将两个张量转换为 numpy 数组:

local tens_a = torch.Tensor((9,8,7,6));
local tens_b = torch.Tensor((9,8,7,6));

a=tens_a.numpy()
b=tens_b.numpy()

然后是类似的东西

np.sum(a==b)
4

会让您对他们的平等程度有一个相当好的了解。

【讨论】:

以上是关于如何检查两个 Torch 张量或矩阵是不是相等?的主要内容,如果未能解决你的问题,请参考以下文章

如何对两个 PyTorch 量化张量进行矩阵相乘?

如何检查不同张量pytorch中的张量值是不是?

如何在 Pytorch 中检查张量是不是在 cuda 上?

如何在 Torch 的网络开头合并两个张量?

如何有效地检索 Torch 张量中最大值的索引?

如何检查存储过程中两个 SELECT 语句的输出是不是相等、大于或小于?