将两个字典与 numpy 矩阵作为值进行比较

Posted

技术标签:

【中文标题】将两个字典与 numpy 矩阵作为值进行比较【英文标题】:Comparing two dictionaries with numpy matrices as values 【发布时间】:2014-12-12 19:04:14 【问题描述】:

我想断言两个 Python 字典是相等的(这意味着:键的数量相等,并且从键到值的每个映射都是相等的;顺序并不重要)。一个简单的方法是assert A==B,但是,如果字典的值为numpy arrays,这将不起作用。如何编写一个函数来检查两个字典是否相等?

>>> import numpy as np
>>> A = 1: np.identity(5)
>>> B = 1: np.identity(5) + np.ones([5,5])
>>> A == B
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

编辑 我知道 numpy 矩阵应检查是否与 .all() 相等。我正在寻找的是一种检查此问题的通用方法,而无需检查isinstance(np.ndarray)。这可能吗?

没有numpy数组的相关主题:

Comparing two dictionaries in Python Comparing/combining two dictionaries

【问题讨论】:

【参考方案1】:

你可以使用numpy.testing.assert_equal

http://docs.scipy.org/doc/numpy/reference/generated/numpy.testing.assert_equal.html

【讨论】:

这不返回布尔值。相反,如果对象不相等,它会引发异常。它可以用来构造一个 cmp 函数,但它本身不是一个。 @GuilhermedeLazari 在某些时候你在这里分裂头发。只需使用 try/except 块创建您的 cmp 函数。它几乎是自己写的。 @GuilhermedeLazari 最初的问题是“我想断言两个 Python 字典是相等的” 只有当您预先知道这些值是 numpy 数组时,此答案才有效。问题是在不首先检查值的实例类型的情况下找到一种通用方法。 我已经有一段时间没有使用它了,但是文档说“给定两个对象(标量、列表、元组、字典或 numpy 数组),检查这些对象的所有元素是否相等。在第一个冲突值处引发异常。”似乎它应该适用于其他类型,但如果你发现这不是真的,那可能是一个错误。【参考方案2】:

我将回答隐藏在你的问题标题和前半部分中的一半问题,因为坦率地说,这是一个更常见的问题需要解决,现有的答案并不能很好地解决它。这个问题是“如何比较两个 numpy 数组的字典是否相等”?

问题的第一部分是“从远处”检查字典:查看它们的键是否相同。如果所有键都相同,则第二部分是比较每个对应的值。

现在微妙的问题是很多 numpy 数组不是整数值的,double-precision is imprecise。因此,除非您有整数值(或其他非浮点型)数组,否则您可能需要检查这些值是否几乎相同,即在机器精度范围内。所以在这种情况下,您不会使用np.array_equal(检查精确的数值相等性),而是使用np.allclose(对两个数组之间的相对和绝对误差使用有限容差)。

问题的前半部分很简单:检查字典的键是否一致,并使用生成器推导来比较每个值(并在推导之外使用all 来验证每个项目是同样):

import numpy as np

# some dummy data

# these are equal exactly
dct1 = 'a': np.array([2, 3, 4])
dct2 = 'a': np.array([2, 3, 4])

# these are equal _roughly_
dct3 = 'b': np.array([42.0, 0.2])
dct4 = 'b': np.array([42.0, 3*0.1 - 0.1])  # still 0.2, right?

def compare_exact(first, second):
    """Return whether two dicts of arrays are exactly equal"""
    if first.keys() != second.keys():
        return False
    return all(np.array_equal(first[key], second[key]) for key in first)

def compare_approximate(first, second):
    """Return whether two dicts of arrays are roughly equal"""
    if first.keys() != second.keys():
        return False
    return all(np.allclose(first[key], second[key]) for key in first)

# let's try them:
print(compare_exact(dct1, dct2))  # True
print(compare_exact(dct3, dct4))  # False
print(compare_approximate(dct3, dct4))  # True

正如您在上面的示例中所看到的,整数数组比较准确,并且取决于您正在做什么(或者如果您很幸运),它甚至可以用于浮点数。但是,如果您的浮点数是任何算术的结果(例如线性变换?),您绝对应该使用近似检查。有关后一个选项的完整描述,请参阅the docs of numpy.allclose(及其元素朋友numpy.isclose),特别注意rtolatol 关键字参数。

【讨论】:

【参考方案3】:

您可以分离两个字典的键和值,并比较键与键以及值与值: 这是解决方案

import numpy as np

def dic_to_keys_values(dic):
    keys, values = list(dic.keys()), list(dic.values())
    return keys, values

def numpy_assert_almost_dict_values(dict1, dict2):
    keys1, values1 = dic_to_keys_values(dict1)
    keys2, values2 = dic_to_keys_values(dict2)
    np.testing.assert_equal(keys1, keys2)
    np.testing.assert_almost_equal(values1, values2)

dict1 = "b": np.array([1, 2, 0.2])
dict2 = "b": np.array([1, 2, 3 * 0.1 - 0.1])  # almost 0.2, but not equal
dict3 = "b": np.array([999, 888, 444]) # completely different

numpy_assert_almost_dict_values(dict1, dict2) # no exception because almost equal
# numpy_assert_almost_dict_values(dict1, dict3) # exception because not equal

(注意,上面检查了精确的键和几乎相等的值)

【讨论】:

【参考方案4】:

考虑这段代码

>>> import numpy as np
>>> np.identity(5)
array([[ 1.,  0.,  0.,  0.,  0.],
       [ 0.,  1.,  0.,  0.,  0.],
       [ 0.,  0.,  1.,  0.,  0.],
       [ 0.,  0.,  0.,  1.,  0.],
       [ 0.,  0.,  0.,  0.,  1.]])
>>> np.identity(5)+np.ones([5,5])
array([[ 2.,  1.,  1.,  1.,  1.],
       [ 1.,  2.,  1.,  1.,  1.],
       [ 1.,  1.,  2.,  1.,  1.],
       [ 1.,  1.,  1.,  2.,  1.],
       [ 1.,  1.,  1.,  1.,  2.]])
>>> np.identity(5) == np.identity(5)+np.ones([5,5])
array([[False, False, False, False, False],
       [False, False, False, False, False],
       [False, False, False, False, False],
       [False, False, False, False, False],
       [False, False, False, False, False]], dtype=bool)
>>> 

注意比较的结果是一个矩阵,而不是一个布尔值。 dict比较将使用values cmp方法比较值,这意味着在比较矩阵值时,dict比较会得到一个复合结果。你想要做的是使用 numpy.all 将复合数组结果折叠成标量布尔结果

>>> np.all(np.identity(5) == np.identity(5)+np.ones([5,5]))
False
>>> np.all(np.identity(5) == np.identity(5))
True
>>> 

您需要编写自己的函数来比较这些字典,测试值类型以查看它们是否为矩阵,然后使用numpy.all 进行比较,否则使用==。当然,如果您愿意,您也可以随时花哨并开始子类化 dict 和重载 cmp

【讨论】:

我对此不是很清楚,但我希望有一种通用的方法,而不需要明确检查类型。今天它是一个numpy数组,明天它是我今天从未听说过的类型。 恐怕没有办法绕过它。如果您(或 numpy 或其他人的)类型覆盖 cmp 以返回非标量,则标准 python 比较将无法处理它。 您不需要编写自己的函数,因为 numpy 已经涵盖了您。请参阅 vitral 的回答。

以上是关于将两个字典与 numpy 矩阵作为值进行比较的主要内容,如果未能解决你的问题,请参考以下文章

将字典列表转换为 numpy 矩阵? [复制]

如何在 C# 中比较两个字典

将字典列表中存在的多个 id 与模型 django 的 id 进行比较

将列表列表与字典进行比较,并将输出作为元组列表的列表

将字典的键与 List 的值进行比较,并返回所有匹配的值,包括重复值

我想将国家/地区列表与作为熊猫数据框 Python 中字典对象类型的列数据进行比较