“获取每组 argmin”的 Numpy 方法

Posted

技术标签:

【中文标题】“获取每组 argmin”的 Numpy 方法【英文标题】:Numpy method for "getting argmin per group" 【发布时间】:2021-12-19 00:38:26 【问题描述】:

我正在尝试找到一种高效的 numpy 循环(与 python 循环相反)方法来获取每组成本最低的数据点的索引。类似于np.minimum.at 所做的事情,但使用“argminimum”而不是最小值。 (并且np.argmin.at 不存在)。

以下演示了我正在寻找的内容:

    names, groups, costs = zip(*[
        ('a', 0, 2.0),  # no (d is lower cost)
        ('b', 1, 3.),  # yes (tied but first)
        ('c', 2, 3.),  # yes (only one)
        ('d', 0, 1.2),  # yes
        ('e', 3, 3.),  # no (k is lower)
        ('f', 4, 3.),  # no (j is lower)
        ('g', 5, 3.),  # yes
        ('h', 1, 3.),  # no (tied but not first)
        ('i', 0, 4.),  # no (d is lower)
        ('j', 4, 2.3),  # yes
        ('k', 3, 0.6),  # yes
        ('l', 5, 7.),  # no (g is lower)
    ])
    mask = get_minimal_unique_index_mask(arr=np.array(groups), values=np.array(costs))
    selected = ''.join(c for c, m in zip(names, mask) if m)
    expected = 'bcdgjk'
    assert selected == expected, f"Selected: 'selected'.  Expected: 'expected'"

我正在尝试找到get_minimal_unique_index_mask 的有效实现。我知道我可以使用 dicts 和 python 循环轻松做到这一点:

def get_minimal_unique_index_mask(groups: Array['N', Any], values: Array['N', float]) -> Array['N', bool]:
    min_ixs_vals = 
    for i, (group, val) in enumerate(zip(groups, values)):
        if group not in min_ixs_vals:
            min_ixs_vals[group] = i
        else:
            min_ixs_vals[group] = i if val < values[min_ixs_vals[group]] else min_ixs_vals[group]
    argmin_per_group_mask = np.zeros(len(groups), dtype=bool)
    argmin_per_group_mask[list(min_ixs_vals.values())] = True
    return argmin_per_group_mask

...上述方法有效,但在 python 中循环,因此会很慢。我想知道是否有一个聪明的 numpy 方法来做同样的事情。

【问题讨论】:

我可能会为此使用 pandas 的 groupby 功能。 谢谢 Quang,我试过了,现在似乎有一个工作功能。 【参考方案1】:

嗯,我想出了如何用 Pandas 来做这件事。不确定循环是否在高效的 C 代码中:

def get_minimal_unique_index_mask(groups: Array['N', Any], values: Array['N', float]) -> MaskArray:
    ixs = pd.DataFrame('groups': groups, 'values': values).groupby('groups')['values'].idxmin().values.astype(int)
    argmin_per_group_mask = np.zeros(len(groups), dtype=bool)
    argmin_per_group_mask[ixs] = True
    return argmin_per_group_mask

感谢this answer的马科斯

【讨论】:

【参考方案2】:

原来熊猫版本很慢。仅在 python 中循环要快 60 倍,与在 C++ 中实现相比可能仍然很慢

def get_minimal_unique_index_mask(groups: Array['N', Any], values: Array['N', float]) -> MaskArray:
    group_to_min_ix_val = 
    for i, (g, v) in enumerate(zip(groups, values)):
        g = str(g)
        if g not in group_to_min_ix_val:
            group_to_min_ix_val[g] = (i, v)
        else:
            _, vmin = group_to_min_ix_val[g]
            if v < vmin:
                group_to_min_ix_val[g] = (i, v)

    argmin_per_group_mask = np.zeros(groups.shape, dtype=bool)
    for i, _ in group_to_min_ix_val.values():
        argmin_per_group_mask[i] = True

    return argmin_per_group_mask

【讨论】:

以上是关于“获取每组 argmin”的 Numpy 方法的主要内容,如果未能解决你的问题,请参考以下文章

使用 NumPy 从矩阵中获取最小/最大 n 值和索引的有效方法

获取多维 NumPy 数组中最大值的位置

Numpy 索引,获取宽度为 2 的波段

在 2D numpy 数组的每个滚动窗口中获取最大值

获取numpy数组中元素的索引

使用 numpy.max/ numpy.min 作为时间戳值