替换 2D numpy 数组中的连续重复项
Posted
技术标签:
【中文标题】替换 2D numpy 数组中的连续重复项【英文标题】:Replace consecutive duplicates in 2D numpy array 【发布时间】:2022-01-07 22:47:53 【问题描述】:我有一个二维 numpy 数组 x
:
import numpy as np
x = np.array([
[1, 2, 8, 4, 5, 5, 5, 3],
[0, 2, 2, 2, 2, 1, 1, 4]
])
我的目标是用一个特定的值替换所有连续的重复数字(让我们采用-1
),但保持一次出现不变。
我可以这样做:
def replace_consecutive_duplicates(x):
consec_dup = np.zeros(x.shape, dtype=bool)
consec_dup[:, 1:] = np.diff(x, axis=1) == 0
x[consec_dup] = -1
return x
# current output
replace_consecutive_duplicates(x)
# array([[ 1, 2, 8, 4, 5, -1, -1, 3],
# [ 0, 2, -1, -1, -1, 1, -1, 4]])
但是,在这种情况下,保持不变的一个事件始终是第一个。
我的目标是保持中间事件不变。
所以给定相同的 x 作为输入,函数replace_consecutive_duplicates
的期望输出是:
# desired output
replace_consecutive_duplicates(x)
# array([[ 1, 2, 8, 4, -1, 5, -1, 3],
# [ 0, -1, 2, -1, -1, 1, -1, 4]])
请注意,如果出现偶数次的连续重复序列,则左中值应保持不变。所以x[1]
中的连续重复序列[2, 2, 2, 2]
变成[-1, 2, -1, -1]
另外请注意,我正在寻找 2D numpy 数组的矢量化解决方案,因为在我的特定用例中性能是绝对重要的。
我已经尝试过查看运行长度编码和使用np.diff()
之类的东西,但我没有设法解决这个问题。希望大家帮忙!
【问题讨论】:
【参考方案1】:主要问题是您需要连续值的长度。这个用numpy不容易搞定,但是使用itertools.groupby
我们可以用下面的代码解决。
import numpy as np
x = np.array([
[1, 2, 8, 4, 5, 5, 5, 3],
[0, 2, 2, 2, 2, 1, 1, 4]
])
def replace_row(arr: np.ndarray, new_val=-1):
results = []
for val, count in itertools.groupby(arr):
k = len(list(count))
results.extend([new_val] * ((k - 1) // 2))
results.append(val)
results.extend([new_val] * (k // 2))
return np.fromiter(results, arr.dtype)
if __name__ == '__main__':
for idx, row in enumerate(x):
x[idx, :] = replace_row(row)
print(x)
输出:
[[ 1 2 8 4 -1 5 -1 3]
[ 0 -1 2 -1 -1 1 -1 4]]
这不是矢量化的,但可以与多线程结合使用,因为每一行都是逐行处理的。
【讨论】:
以上是关于替换 2D numpy 数组中的连续重复项的主要内容,如果未能解决你的问题,请参考以下文章