在numpy中标记组内的元素
Posted
技术标签:
【中文标题】在numpy中标记组内的元素【英文标题】:Label elements within a group in numpy 【发布时间】:2022-01-19 07:14:19 【问题描述】:我知道如何标记一个输入数组的元素,如下所示:
arr_value = np.array([0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 2, 1, 1, 1, 1])
arr_res_1 = np.array([0, 1, 2, 3, 3, 4, 5, 6, 7, 7, 8, 9, 9, 9, 9]) # consider zeros in arr_value as elements
arr_res_2 = np.array([0, 1, 0, 2, 2, 0, 3, 0, 4, 4, 5, 6, 6, 6, 6]) # do not consider zeros in arr_value as elements
def shift(arr: np.array, n: int, fill_value=np.nan):
res = np.empty_like(arr)
if n > 0:
res[:n] = fill_value
res[n:] = arr[:-n]
elif n < 0:
res[n:] = fill_value
res[:n] = arr[-n:]
else:
res[:] = arr
return res
def np_label(arr: np.array, replace_zero: bool = True):
arr_shift = shift(arr, 1, fill_value=0)
label = np.where(arr != arr_shift, 1, 0)
if replace_zero:
mask_zero = arr == 0
label[mask_zero] = 0
label = np.cumsum(label)
label[mask_zero] = 0
return label
else:
return np.cumsum(label)
现在,有两个输入数组,包括组和值数组。标签在新组的第一个元素上重置,如果对应的值为0,则为0,否则从1开始。如何不拆分数组或迭代?
arr_group = np.array([0, 1, 1, 1, 1, 1, 1, 1, 0, 2, 2, 0, 3, 3, 4])
arr_value = np.array([0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 2, 1, 1, 1, 1])
arr_res_1 = np.array([0, 1, 2, 3, 3, 4, 5, 6, 0, 1, 2, 0, 1, 1, 1]) # consider zeros in arr_value as elements
arr_res_2 = np.array([0, 1, 0, 2, 2, 0, 3, 0, 0, 1, 2, 0, 1, 1, 1]) # do not consider zeros in arr_value as elements
【问题讨论】:
【参考方案1】:在计算np.cumsum
之前,您需要找到一种方法来减去每个组的最大索引。 np.add.reduceat
允许您找到这些结果,而无需之前拆分数组。如果您传递将您的组分开的索引,您将获得每个组的总和。
def refresh_groups(label: np.array, mask_group: np.array):
mark_idx = np.flatnonzero(mask_group)
reducer = np.add.reduceat(label, mark_idx)
label[mark_idx[1:]] -= reducer[:-1]
def np_label(arr: np.array, group: np.array, replace_zero: bool = True, replace_group: bool = True):
arr_shift = shift(arr, 1, fill_value=0)
label = np.where(arr != arr_shift, 1, 0)
if replace_zero:
mask_zero = arr == 0
label[mask_zero] = 0
if replace_group:
mask_group = group == 0
refresh_groups(label, mask_group)
label = np.cumsum(label)
if replace_zero:
label[mask_zero] = 0
if replace_group:
label[mask_group] = 0
return label
np_label(arr_value, arr_group, False, False)
只需尝试 replace_zero
和 replace_group
参数的 4 个不同选项来检查它是否与您的预期输出匹配。
【讨论】:
以上是关于在numpy中标记组内的元素的主要内容,如果未能解决你的问题,请参考以下文章