“获取每组 argmin”的 Numpy 方法
Posted
技术标签:
【中文标题】“获取每组 argmin”的 Numpy 方法【英文标题】:Numpy method for "getting argmin per group" 【发布时间】:2021-12-19 00:38:26 【问题描述】:我正在尝试找到一种高效的 numpy 循环(与 python 循环相反)方法来获取每组成本最低的数据点的索引。类似于np.minimum.at
所做的事情,但使用“argminimum”而不是最小值。 (并且np.argmin.at
不存在)。
以下演示了我正在寻找的内容:
names, groups, costs = zip(*[
('a', 0, 2.0), # no (d is lower cost)
('b', 1, 3.), # yes (tied but first)
('c', 2, 3.), # yes (only one)
('d', 0, 1.2), # yes
('e', 3, 3.), # no (k is lower)
('f', 4, 3.), # no (j is lower)
('g', 5, 3.), # yes
('h', 1, 3.), # no (tied but not first)
('i', 0, 4.), # no (d is lower)
('j', 4, 2.3), # yes
('k', 3, 0.6), # yes
('l', 5, 7.), # no (g is lower)
])
mask = get_minimal_unique_index_mask(arr=np.array(groups), values=np.array(costs))
selected = ''.join(c for c, m in zip(names, mask) if m)
expected = 'bcdgjk'
assert selected == expected, f"Selected: 'selected'. Expected: 'expected'"
我正在尝试找到get_minimal_unique_index_mask
的有效实现。我知道我可以使用 dicts 和 python 循环轻松做到这一点:
def get_minimal_unique_index_mask(groups: Array['N', Any], values: Array['N', float]) -> Array['N', bool]:
min_ixs_vals =
for i, (group, val) in enumerate(zip(groups, values)):
if group not in min_ixs_vals:
min_ixs_vals[group] = i
else:
min_ixs_vals[group] = i if val < values[min_ixs_vals[group]] else min_ixs_vals[group]
argmin_per_group_mask = np.zeros(len(groups), dtype=bool)
argmin_per_group_mask[list(min_ixs_vals.values())] = True
return argmin_per_group_mask
...上述方法有效,但在 python 中循环,因此会很慢。我想知道是否有一个聪明的 numpy 方法来做同样的事情。
【问题讨论】:
我可能会为此使用 pandas 的 groupby 功能。 谢谢 Quang,我试过了,现在似乎有一个工作功能。 【参考方案1】:嗯,我想出了如何用 Pandas 来做这件事。不确定循环是否在高效的 C 代码中:
def get_minimal_unique_index_mask(groups: Array['N', Any], values: Array['N', float]) -> MaskArray:
ixs = pd.DataFrame('groups': groups, 'values': values).groupby('groups')['values'].idxmin().values.astype(int)
argmin_per_group_mask = np.zeros(len(groups), dtype=bool)
argmin_per_group_mask[ixs] = True
return argmin_per_group_mask
感谢this answer的马科斯
【讨论】:
【参考方案2】:原来熊猫版本很慢。仅在 python 中循环要快 60 倍,与在 C++ 中实现相比可能仍然很慢
def get_minimal_unique_index_mask(groups: Array['N', Any], values: Array['N', float]) -> MaskArray:
group_to_min_ix_val =
for i, (g, v) in enumerate(zip(groups, values)):
g = str(g)
if g not in group_to_min_ix_val:
group_to_min_ix_val[g] = (i, v)
else:
_, vmin = group_to_min_ix_val[g]
if v < vmin:
group_to_min_ix_val[g] = (i, v)
argmin_per_group_mask = np.zeros(groups.shape, dtype=bool)
for i, _ in group_to_min_ix_val.values():
argmin_per_group_mask[i] = True
return argmin_per_group_mask
【讨论】:
以上是关于“获取每组 argmin”的 Numpy 方法的主要内容,如果未能解决你的问题,请参考以下文章