在numpy中标记组内的元素

Posted

技术标签:

【中文标题】在numpy中标记组内的元素【英文标题】:Label elements within a group in numpy 【发布时间】:2022-01-19 07:14:19 【问题描述】:

我知道如何标记一个输入数组的元素,如下所示:

arr_value = np.array([0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 2, 1, 1, 1, 1])
arr_res_1 = np.array([0, 1, 2, 3, 3, 4, 5, 6, 7, 7, 8, 9, 9, 9, 9])  # consider zeros in arr_value as elements
arr_res_2 = np.array([0, 1, 0, 2, 2, 0, 3, 0, 4, 4, 5, 6, 6, 6, 6])  # do not consider zeros in arr_value as elements

def shift(arr: np.array, n: int, fill_value=np.nan):
    res = np.empty_like(arr)
    if n > 0:
        res[:n] = fill_value
        res[n:] = arr[:-n]
    elif n < 0:
        res[n:] = fill_value
        res[:n] = arr[-n:]
    else:
        res[:] = arr
    return res

def np_label(arr: np.array, replace_zero: bool = True):
    arr_shift = shift(arr, 1, fill_value=0)
    label = np.where(arr != arr_shift, 1, 0)
    if replace_zero:
        mask_zero = arr == 0
        label[mask_zero] = 0
        label = np.cumsum(label)
        label[mask_zero] = 0
        return label
    else:
        return np.cumsum(label)

现在,有两个输入数组,包括组和值数组。标签在新组的第一个元素上重置,如果对应的值为0,则为0,否则从1开始。如何不拆分数组或迭代?

arr_group = np.array([0, 1, 1, 1, 1, 1, 1, 1, 0, 2, 2, 0, 3, 3, 4])
arr_value = np.array([0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 2, 1, 1, 1, 1])
arr_res_1 = np.array([0, 1, 2, 3, 3, 4, 5, 6, 0, 1, 2, 0, 1, 1, 1])  # consider zeros in arr_value as elements
arr_res_2 = np.array([0, 1, 0, 2, 2, 0, 3, 0, 0, 1, 2, 0, 1, 1, 1])  # do not consider zeros in arr_value as elements

【问题讨论】:

【参考方案1】:

在计算np.cumsum 之前,您需要找到一种方法来减去每个组的最大索引。 np.add.reduceat 允许您找到这些结果,而无需之前拆分数组。如果您传递将您的组分开的索引,您将获得每个组的总和。

def refresh_groups(label: np.array, mask_group: np.array):
    mark_idx = np.flatnonzero(mask_group)
    reducer = np.add.reduceat(label, mark_idx)
    label[mark_idx[1:]] -= reducer[:-1]
        
def np_label(arr: np.array, group: np.array, replace_zero: bool = True, replace_group: bool = True):
    arr_shift = shift(arr, 1, fill_value=0)
    label = np.where(arr != arr_shift, 1, 0)
    
    if replace_zero:
        mask_zero = arr == 0
        label[mask_zero] = 0
    if replace_group:
        mask_group = group == 0       
        refresh_groups(label, mask_group)  
        
    label = np.cumsum(label)
    
    if replace_zero:
        label[mask_zero] = 0
    if replace_group:
        label[mask_group] = 0

    return label
    
np_label(arr_value, arr_group, False, False)

只需尝试 replace_zeroreplace_group 参数的 4 个不同选项来检查它是否与您的预期输出匹配。

【讨论】:

以上是关于在numpy中标记组内的元素的主要内容,如果未能解决你的问题,请参考以下文章

计算每个组内的元素数

如何内联输入组内的元素(强制它们在一行上)

使用 numpy.where() 在给定范围内的两个数组中搜索元素

怎么删除numpy矩阵内的一些元素

2607. 使子数组元素和相等

Chrome开发人员工具不会检查body标记内的元素(mac OSX)