Numpy:在给定索引的情况下,如何以有效的方式摆脱轴 = 1 的最小值?
Posted
技术标签:
【中文标题】Numpy:在给定索引的情况下,如何以有效的方式摆脱轴 = 1 的最小值?【英文标题】:Numpy: How to get rid of the minima along axis=1, given the indices - in an efficient way? 【发布时间】:2012-03-11 05:09:14 【问题描述】:给定一个形状为(1000000,6)
的矩阵A,我已经想出了如何为每一行获取最小最右边的值并在这个函数中实现它:
def calculate_row_minima_indices(h): # h is the given matrix.
"""Returns the indices of the rightmost minimum per row for matrix h."""
flipped = numpy.fliplr(h) # flip the matrix to get the rightmost minimum.
flipped_indices = numpy.argmin(flipped, axis=1)
indices = numpy.array([2]*dim) - flipped_indices
return indices
indices = calculate_row_minima_indices(h)
for col, row in enumerate(indices):
print col, row, h[col][row] # col_index, row_index and value of minimum which should be removed.
每一行都有一个最小值。所以我需要知道的是删除具有最小值的条目并将形状为(1000000,6)
的矩阵收缩为带有的矩阵形状(1000000,5)
。
我会生成一个具有较低维度的新矩阵,并使用 for 循环使用我希望它携带的值填充它,但我担心运行时。 那么有没有一些内置的方法或技巧可以通过每行的最小值来缩小矩阵?
也许这个信息是有用的:这些值都大于或等于 0.0。
【问题讨论】:
【参考方案1】:可以使用布尔掩码数组进行选择,但内存使用量有点大。
import numpy
h = numpy.random.randint(0, 10, (20, 6))
flipped = numpy.fliplr(h) # flip the matrix to get the rightmost minimum.
flipped_indices = numpy.argmin(flipped, axis=1)
indices = 5 - flipped_indices
mask = numpy.ones(h.shape, numpy.bool)
mask[numpy.arange(h.shape[0]), indices] = False
result = h[mask].reshape(-1, 5)
【讨论】:
【参考方案2】:假设您有足够的内存来保存原始数组和新数组的形状的布尔掩码,这是一种方法:
import numpy as np
def main():
np.random.seed(1) # For reproducibility
data = generate_data((10, 6))
indices = rightmost_min_col(data)
new_data = pop_col(data, indices)
print 'Original data...'
print data
print 'Modified data...'
print new_data
def generate_data(shape):
return np.random.randint(0, 10, shape)
def rightmost_min_col(data):
nrows, ncols = data.shape[:2]
min_indices = np.fliplr(data).argmin(axis=1)
min_indices = (ncols - 1) - min_indices
return min_indices
def pop_col(data, col_indices):
nrows, ncols = data.shape[:2]
col_indices = col_indices[:, np.newaxis]
row_indices = np.arange(ncols)[np.newaxis, :]
mask = col_indices != row_indices
return data[mask].reshape((nrows, ncols-1))
if __name__ == '__main__':
main()
这会产生:
Original data...
[[5 8 9 5 0 0]
[1 7 6 9 2 4]
[5 2 4 2 4 7]
[7 9 1 7 0 6]
[9 9 7 6 9 1]
[0 1 8 8 3 9]
[8 7 3 6 5 1]
[9 3 4 8 1 4]
[0 3 9 2 0 4]
[9 2 7 7 9 8]]
Modified data...
[[5 8 9 5 0]
[7 6 9 2 4]
[5 2 4 4 7]
[7 9 1 7 6]
[9 9 7 6 9]
[1 8 8 3 9]
[8 7 3 6 5]
[9 3 4 8 4]
[0 3 9 2 4]
[9 7 7 9 8]]
我在这里使用的可读性较低的技巧之一是在数组比较期间利用 numpy 的广播。举个简单的例子,考虑以下几点:
import numpy as np
a = np.array([[1, 2, 3]])
b = np.array([[1],[2],[3]])
print a == b
这会产生:
array([[ True, False, False],
[False, True, False],
[False, False, True]], dtype=bool)
因此,如果我们知道要删除的项目的列索引,我们可以对列索引数组的操作进行向量化,这就是 pop_col
所做的。
【讨论】:
你可以做mask = np.ones(data.shape, 'bool'); mask[np.arange(nrows), col_indices] = False
。这可能更具可读性。
不过,它不会做同样的事情。
它似乎做同样的事情,你能解释一下它们有什么不同吗?
你说得对,我昨晚有点紧张。感谢您的建议!以上是关于Numpy:在给定索引的情况下,如何以有效的方式摆脱轴 = 1 的最小值?的主要内容,如果未能解决你的问题,请参考以下文章
如何在没有科学记数法和给定精度的情况下漂亮地打印 numpy.array?
使用 NumPy 从矩阵中获取最小/最大 n 值和索引的有效方法
两个 1D numpy / torch 数组的特殊索引以生成另一个数组