Accessing the sequential rows of a numpy array with a given set of indices


【Posted】: 2022-01-23 03:28:35

【Question】:

I have a numpy array (x) in which the last column of each row holds that row's index number.

import numpy as np
import random 
np.random.seed(0)
x = np.random.random([5,3])
x = np.append(x, np.arange(x.shape[0]).reshape(-1,1), axis=1) 
x=
array([[0.5488135 , 0.71518937, 0.60276338, 0.        ],
       [0.54488318, 0.4236548 , 0.64589411, 1.        ],
       [0.43758721, 0.891773  , 0.96366276, 2.        ],
       [0.38344152, 0.79172504, 0.52889492, 3.        ],
       [0.56804456, 0.92559664, 0.07103606, 4.        ]])

I have another numpy array, y, that is related to the first one: each row of x has a user-defined number (rep) of related rows in y.

rep = 4
y = np.random.random([rep*5,3])
array([[0.0871293 , 0.0202184 , 0.83261985],
       [0.77815675, 0.87001215, 0.97861834],
       [0.79915856, 0.46147936, 0.78052918],
       [0.11827443, 0.63992102, 0.14335329],
       [0.94466892, 0.52184832, 0.41466194],
       [0.26455561, 0.77423369, 0.45615033],
       [0.56843395, 0.0187898 , 0.6176355 ],
       [0.61209572, 0.616934  , 0.94374808],
       [0.6818203 , 0.3595079 , 0.43703195],
       [0.6976312 , 0.06022547, 0.66676672],
       [0.67063787, 0.21038256, 0.1289263 ],
       [0.31542835, 0.36371077, 0.57019677],
       [0.43860151, 0.98837384, 0.10204481],
       [0.20887676, 0.16130952, 0.65310833],
       [0.2532916 , 0.46631077, 0.24442559],
       [0.15896958, 0.11037514, 0.65632959],
       [0.13818295, 0.19658236, 0.36872517],
       [0.82099323, 0.09710128, 0.83794491],
       [0.09609841, 0.97645947, 0.4686512 ],
       [0.97676109, 0.60484552, 0.73926358]])

For example, index 0 in x is related to indices 0,1,2,3 in y.

Suppose that, after calling a method, I get a set of indices taken from the last column of x.

ind = my_method(x) #Note that it can be any permutation of number 0 to n-1 where n is the number of rows in x
ind
[4, 0] #For the sake of simplicity, let us assume that the method returns [4,0]

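I would like to know the most efficient way (e.g. when there are millions of rows) to access the rows of y using the given set of indices. For example, if I have ind = [4,0], then with rep = 4 I want to get rows 16,17,18,19,0,1,2,3 of y.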

Expected output:

       [[0.13818295, 0.19658236, 0.36872517],
       [0.82099323, 0.09710128, 0.83794491],
       [0.09609841, 0.97645947, 0.4686512 ],
       [0.97676109, 0.60484552, 0.73926358],
       [0.0871293 , 0.0202184 , 0.83261985],
       [0.77815675, 0.87001215, 0.97861834],
       [0.79915856, 0.46147936, 0.78052918],
       [0.11827443, 0.63992102, 0.14335329]]

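【Comments】:

This question is hard to follow. Could you provide a sample of the output you expect? Also, please use np.random.seed(0) so that everyone gets the same random values as you.
@richardec I have updated my post.
My guess is that np.r_ would be a very fast way: ***.com/questions/34188620/…

【Answer 1】: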
import numpy as np
import random 

np.random.seed(0)

n,m = 10, 20

x = np.random.random([n,m])
x = np.append(x, np.arange(x.shape[0]).reshape(-1,1), axis=1) 

rep = 3

y = np.random.random([rep*n,m])

ind = np.array([0, 2 , 1]) 

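With rep = 3, the chosen ind = [0, 2, 1] refers to blocks 0, 2 and 1, so every row you need lies within the first nine rows of y:

y[:9, :]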
array([[0.31179588, 0.69634349, 0.37775184, 0.17960368, 0.02467873,
        0.06724963, 0.67939277, 0.45369684, 0.53657921, 0.89667129,
        0.99033895, 0.21689698, 0.6630782 , 0.26332238, 0.020651  ,
        0.75837865, 0.32001715, 0.38346389, 0.58831711, 0.83104846],
       [0.62898184, 0.87265066, 0.27354203, 0.79804683, 0.18563594,
        0.95279166, 0.68748828, 0.21550768, 0.94737059, 0.73085581,
        0.25394164, 0.21331198, 0.51820071, 0.02566272, 0.20747008,
        0.42468547, 0.37416998, 0.46357542, 0.27762871, 0.58678435],
       [0.86385561, 0.11753186, 0.51737911, 0.13206811, 0.71685968,
        0.3960597 , 0.56542131, 0.18327984, 0.14484776, 0.48805628,
        0.35561274, 0.94043195, 0.76532525, 0.74866362, 0.90371974,
        0.08342244, 0.55219247, 0.58447607, 0.96193638, 0.29214753],
       [0.24082878, 0.10029394, 0.01642963, 0.92952932, 0.66991655,
        0.78515291, 0.28173011, 0.58641017, 0.06395527, 0.4856276 ,
        0.97749514, 0.87650525, 0.33815895, 0.96157015, 0.23170163,
        0.94931882, 0.9413777 , 0.79920259, 0.63044794, 0.87428797],
       [0.29302028, 0.84894356, 0.61787669, 0.01323686, 0.34723352,
        0.14814086, 0.98182939, 0.47837031, 0.49739137, 0.63947252,
        0.36858461, 0.13690027, 0.82211773, 0.18984791, 0.51131898,
        0.22431703, 0.09784448, 0.86219152, 0.97291949, 0.96083466],
       [0.9065555 , 0.77404733, 0.33314515, 0.08110139, 0.40724117,
        0.23223414, 0.13248763, 0.05342718, 0.72559436, 0.01142746,
        0.77058075, 0.14694665, 0.07952208, 0.08960303, 0.67204781,
        0.24536721, 0.42053947, 0.55736879, 0.86055117, 0.72704426],
       [0.27032791, 0.1314828 , 0.05537432, 0.30159863, 0.26211815,
        0.45614057, 0.68328134, 0.69562545, 0.28351885, 0.37992696,
        0.18115096, 0.78854551, 0.05684808, 0.69699724, 0.7786954 ,
        0.77740756, 0.25942256, 0.37381314, 0.58759964, 0.2728219 ],
       [0.3708528 , 0.19705428, 0.45985588, 0.0446123 , 0.79979588,
        0.07695645, 0.51883515, 0.3068101 , 0.57754295, 0.95943334,
        0.64557024, 0.03536244, 0.43040244, 0.51001685, 0.53617749,
        0.68139251, 0.2775961 , 0.12886057, 0.39267568, 0.95640572],
       [0.18713089, 0.90398395, 0.54380595, 0.45691142, 0.88204141,
        0.45860396, 0.72416764, 0.39902532, 0.90404439, 0.69002502,
        0.69962205, 0.3277204 , 0.75677864, 0.63606106, 0.24002027,
        0.16053882, 0.79639147, 0.9591666 , 0.45813883, 0.59098417]])

The exact row indices you need can be obtained with ind[:,np.newaxis]*rep + range(rep), which outputs:

array([[0, 1, 2],
       [6, 7, 8],
       [3, 4, 5]])
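
As a minimal sketch of why that expression yields this grid (the names block_starts and rows are illustrative and do not appear in the answer): ind[:, np.newaxis] turns ind into a column of shape (3, 1), and broadcasting then adds every offset 0..rep-1 to each block's first row.

```python
import numpy as np

rep = 3
ind = np.array([0, 2, 1])

# Column vector of the first row of each selected block: [[0], [6], [3]]
block_starts = ind[:, np.newaxis] * rep

# Broadcasting (3, 1) + (3,) -> (3, 3): one row of indices per selected block
rows = block_starts + np.arange(rep)
print(rows)
```

Each row of the result lists the y rows that belong to one entry of ind, in the order given by ind.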

Finally, you can use the following command to get the required rows in the proper shape.

y[ ind[:,None]*rep  + range(rep), :].reshape(-1,m)


array([[0.31179588, 0.69634349, 0.37775184, 0.17960368, 0.02467873,
        0.06724963, 0.67939277, 0.45369684, 0.53657921, 0.89667129,
        0.99033895, 0.21689698, 0.6630782 , 0.26332238, 0.020651  ,
        0.75837865, 0.32001715, 0.38346389, 0.58831711, 0.83104846],
       [0.62898184, 0.87265066, 0.27354203, 0.79804683, 0.18563594,
        0.95279166, 0.68748828, 0.21550768, 0.94737059, 0.73085581,
        0.25394164, 0.21331198, 0.51820071, 0.02566272, 0.20747008,
        0.42468547, 0.37416998, 0.46357542, 0.27762871, 0.58678435],
       [0.86385561, 0.11753186, 0.51737911, 0.13206811, 0.71685968,
        0.3960597 , 0.56542131, 0.18327984, 0.14484776, 0.48805628,
        0.35561274, 0.94043195, 0.76532525, 0.74866362, 0.90371974,
        0.08342244, 0.55219247, 0.58447607, 0.96193638, 0.29214753],
       [0.27032791, 0.1314828 , 0.05537432, 0.30159863, 0.26211815,
        0.45614057, 0.68328134, 0.69562545, 0.28351885, 0.37992696,
        0.18115096, 0.78854551, 0.05684808, 0.69699724, 0.7786954 ,
        0.77740756, 0.25942256, 0.37381314, 0.58759964, 0.2728219 ],
       [0.3708528 , 0.19705428, 0.45985588, 0.0446123 , 0.79979588,
        0.07695645, 0.51883515, 0.3068101 , 0.57754295, 0.95943334,
        0.64557024, 0.03536244, 0.43040244, 0.51001685, 0.53617749,
        0.68139251, 0.2775961 , 0.12886057, 0.39267568, 0.95640572],
       [0.18713089, 0.90398395, 0.54380595, 0.45691142, 0.88204141,
        0.45860396, 0.72416764, 0.39902532, 0.90404439, 0.69002502,
        0.69962205, 0.3277204 , 0.75677864, 0.63606106, 0.24002027,
        0.16053882, 0.79639147, 0.9591666 , 0.45813883, 0.59098417],
       [0.24082878, 0.10029394, 0.01642963, 0.92952932, 0.66991655,
        0.78515291, 0.28173011, 0.58641017, 0.06395527, 0.4856276 ,
        0.97749514, 0.87650525, 0.33815895, 0.96157015, 0.23170163,
        0.94931882, 0.9413777 , 0.79920259, 0.63044794, 0.87428797],
       [0.29302028, 0.84894356, 0.61787669, 0.01323686, 0.34723352,
        0.14814086, 0.98182939, 0.47837031, 0.49739137, 0.63947252,
        0.36858461, 0.13690027, 0.82211773, 0.18984791, 0.51131898,
        0.22431703, 0.09784448, 0.86219152, 0.97291949, 0.96083466],
       [0.9065555 , 0.77404733, 0.33314515, 0.08110139, 0.40724117,
        0.23223414, 0.13248763, 0.05342718, 0.72559436, 0.01142746,
        0.77058075, 0.14694665, 0.07952208, 0.08960303, 0.67204781,
        0.24536721, 0.42053947, 0.55736879, 0.86055117, 0.72704426]])

Hope this helps. I tried to make my answer more general; you can modify it as needed.
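
For reference, here is a sketch of the same idea applied to the data from the question (rep = 4, ind = [4, 0]). The variable names rows and selected are illustrative, and the arrays are rebuilt with the question's seed only so that the snippet is self-contained.

```python
import numpy as np

np.random.seed(0)
x = np.random.random([5, 3])
rep = 4
y = np.random.random([rep * x.shape[0], 3])

ind = np.array([4, 0])

# Block i of y covers rows i*rep .. i*rep + rep - 1; flatten to a 1-D index list
rows = (ind[:, None] * rep + np.arange(rep)).ravel()

# Integer-array indexing returns the selected rows as a new 2-D array
selected = y[rows]
print(rows)
print(selected.shape)
```

Flattening the index grid with ravel() before indexing gives a 2-D result directly, so no reshape of the selected data is needed afterwards.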

【Comments】:

【Answer 2】:

I think you need something like:

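ind = np.array(ind)
# each index i in ind maps to rows i*rep .. i*rep + rep - 1 of y (rep = 4 in the question)
rows_in_y = ind[:, np.newaxis] * rep + np.arange(rep)
y[rows_in_y, :].reshape(-1, y.shape[1])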

I am not sure what you are trying to achieve, but this looks like a fairly ordinary indexing problem.
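
An alternative that neither answer proposes, sketched here under the assumption that y is C-contiguous and that every row of x owns exactly rep consecutive rows of y: view y as a 3-D array of blocks and index whole blocks with ind. The names blocks, by_blocks and by_rows are illustrative, and the random y below is not the one from the question.

```python
import numpy as np

np.random.seed(0)
n, m, rep = 5, 3, 4
y = np.random.random([rep * n, m])
ind = np.array([4, 0])

# View y as n blocks of rep rows each; for a C-contiguous y this makes no copy
blocks = y.reshape(n, rep, m)

# Select whole blocks in the order given by ind, then flatten back to 2-D
by_blocks = blocks[ind].reshape(-1, m)

# Same selection using the broadcast row indices, for comparison
by_rows = y[(ind[:, None] * rep + np.arange(rep)).ravel()]
print(np.array_equal(by_blocks, by_rows))
```

Whether this is faster than building the index grid depends on the array sizes, so it may be worth timing both variants on realistic data before choosing one.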

【Comments】:
