numpy.testing.assert_array_equal 失败,两个相同的参差不齐的数组数组
Posted
技术标签:
【中文标题】numpy.testing.assert_array_equal 失败,两个相同的参差不齐的数组数组【英文标题】:numpy.testing.assert_array_equal fails with two identical ragged arrays of arrays 【发布时间】:2021-12-23 20:42:13 【问题描述】:我有两个 numpy 数组,我想测试是否相等。
以下工作正常:
# this works
x = np.array([np.array(['a', 'b']), np.array(['c', 'd'])], dtype='object')
y = np.array([np.array(['a', 'b']), np.array(['c', 'd'])], dtype='object')
assert np.testing.assert_array_equal(x,y)
但是,如果其中一个内部数组参差不齐,则比较失败:
# this works
x = np.array([np.array(['a', 'b']), np.array(['c'])], dtype='object')
y = np.array([np.array(['a', 'b']), np.array(['c'])], dtype='object')
np.testing.assert_array_equal(x,y)
Traceback (most recent call last):
File "/home/.../test.py", line 12, in <module>
np.testing.assert_array_equal(x,y)
File "/home/.../lib/python3.9/site-packages/numpy/testing/_private/utils.py", line 932, in assert_array_equal
assert_array_compare(operator.__eq__, x, y, err_msg=err_msg,
File "/home/.../lib/python3.9/site-packages/numpy/testing/_private/utils.py", line 842, in assert_array_compare
raise AssertionError(msg)
AssertionError:
Arrays are not equal
Mismatched elements: 1 / 1 (100%)
x: array([array(['a', 'b'], dtype='<U1'), array(['c'], dtype='<U1')],
dtype=object)
y: array([array(['a', 'b'], dtype='<U1'), array(['c'], dtype='<U1')],
dtype=object)
更新:
为了使故事更加晦涩难懂,以下作品:
x = np.array([np.array(['a', 'b']), np.array(['c'])], dtype='object')
y = x
np.testing.assert_array_equal(x,y)
这是正确的行为吗?
【问题讨论】:
在 2 种情况下显示x==y
(错误时回溯)
有趣的是,如果你设置 x2 = x
然后运行 np.testing.assert_array_equal(x,x2)
它通过了,然后如果你将 x 重新初始化为相同的参差不齐的数组 np.testing.assert_array_equal(x,x2)
失败
你是对的:我会更新问题。这真的很奇怪。
【参考方案1】:
在第一种情况下,数组是 (2,2)(尽管是 object dtype):
In [20]: x = np.array([np.array(['a', 'b']), np.array(['c', 'd'])], dtype='object')
...: y = np.array([np.array(['a', 'b']), np.array(['c', 'd'])], dtype='object')
In [21]: x
Out[21]:
array([['a', 'b'],
['c', 'd']], dtype=object)
In [22]: x.shape
Out[22]: (2, 2)
In [23]: x==y
Out[23]:
array([[ True, True],
[ True, True]])
断言只需要验证此比较的所有元素是否为真
第二种情况:
In [24]: x = np.array([np.array(['a', 'b']), np.array(['c'])], dtype='object')
...: y = np.array([np.array(['a', 'b']), np.array(['c'])], dtype='object')
In [25]: x
Out[25]:
array([array(['a', 'b'], dtype='<U1'), array(['c'], dtype='<U1')],
dtype=object)
In [26]: x.shape
Out[26]: (2,)
In [27]: x==y
<ipython-input-27-051436df861e>:1: DeprecationWarning: elementwise comparison failed;
this will raise an error in the future.
x==y
Out[27]: False
结果是一个标量,而不是 (2,) 数组。 x==x
生成 True
,并带有相同的警告。
数组元素可以成对比较:
In [30]: [i==j for i,j in zip(x,y)]
Out[30]: [array([ True, True]), array([ True])]
【讨论】:
以上是关于numpy.testing.assert_array_equal 失败,两个相同的参差不齐的数组数组的主要内容,如果未能解决你的问题,请参考以下文章