numpy.testing.assert_array_equal 失败,两个相同的参差不齐的数组数组

Posted

技术标签:

【中文标题】numpy.testing.assert_array_equal 失败,两个相同的参差不齐的数组数组【英文标题】:numpy.testing.assert_array_equal fails with two identical ragged arrays of arrays 【发布时间】:2021-12-23 20:42:13 【问题描述】:

我有两个 numpy 数组,我想测试是否相等。

以下工作正常:

# this works
x = np.array([np.array(['a', 'b']), np.array(['c', 'd'])], dtype='object')
y = np.array([np.array(['a', 'b']), np.array(['c', 'd'])], dtype='object')
assert np.testing.assert_array_equal(x,y)

但是,如果其中一个内部数组参差不齐,则比较失败:

# this works
x = np.array([np.array(['a', 'b']), np.array(['c'])], dtype='object')
y = np.array([np.array(['a', 'b']), np.array(['c'])], dtype='object')
np.testing.assert_array_equal(x,y)

Traceback (most recent call last):
  File "/home/.../test.py", line 12, in <module>
    np.testing.assert_array_equal(x,y)
  File "/home/.../lib/python3.9/site-packages/numpy/testing/_private/utils.py", line 932, in assert_array_equal
    assert_array_compare(operator.__eq__, x, y, err_msg=err_msg,
  File "/home/.../lib/python3.9/site-packages/numpy/testing/_private/utils.py", line 842, in assert_array_compare
    raise AssertionError(msg)
AssertionError: 
Arrays are not equal

Mismatched elements: 1 / 1 (100%)
 x: array([array(['a', 'b'], dtype='<U1'), array(['c'], dtype='<U1')],
      dtype=object)
 y: array([array(['a', 'b'], dtype='<U1'), array(['c'], dtype='<U1')],
      dtype=object)

更新:

为了使故事更加晦涩难懂,以下作品:

x = np.array([np.array(['a', 'b']), np.array(['c'])], dtype='object')
y = x
np.testing.assert_array_equal(x,y)

这是正确的行为吗?

【问题讨论】:

在 2 种情况下显示 x==y(错误时回溯) 有趣的是,如果你设置 x2 = x 然后运行 ​​np.testing.assert_array_equal(x,x2) 它通过了,然后如果你将 x 重新初始化为相同的参差不齐的数组 np.testing.assert_array_equal(x,x2) 失败 你是对的:我会更新问题。这真的很奇怪。 【参考方案1】:

在第一种情况下,数组是 (2,2)(尽管是 object dtype):

In [20]: x = np.array([np.array(['a', 'b']), np.array(['c', 'd'])], dtype='object')
    ...: y = np.array([np.array(['a', 'b']), np.array(['c', 'd'])], dtype='object')
In [21]: x
Out[21]: 
array([['a', 'b'],
       ['c', 'd']], dtype=object)
In [22]: x.shape
Out[22]: (2, 2)
In [23]: x==y
Out[23]: 
array([[ True,  True],
       [ True,  True]])

断言只需要验证此比较的所有元素是否为真

第二种情况:

In [24]: x = np.array([np.array(['a', 'b']), np.array(['c'])], dtype='object')
    ...: y = np.array([np.array(['a', 'b']), np.array(['c'])], dtype='object')
In [25]: x
Out[25]: 
array([array(['a', 'b'], dtype='<U1'), array(['c'], dtype='<U1')],
      dtype=object)
In [26]: x.shape
Out[26]: (2,)
In [27]: x==y
<ipython-input-27-051436df861e>:1: DeprecationWarning: elementwise comparison failed; 
 this will raise an error in the future.
  x==y
Out[27]: False

结果是一个标量,而不是 (2,) 数组。 x==x 生成 True,并带有相同的警告。

数组元素可以成对比较:

In [30]: [i==j for i,j in zip(x,y)]
Out[30]: [array([ True,  True]), array([ True])]

【讨论】:

以上是关于numpy.testing.assert_array_equal 失败,两个相同的参差不齐的数组数组的主要内容,如果未能解决你的问题,请参考以下文章