给定一个 numpy 数组视图中项目的索引,在基本数组中找到它的索引

Posted

技术标签:

【中文标题】给定一个 numpy 数组视图中项目的索引,在基本数组中找到它的索引【英文标题】:given the index of an item in a view of a numpy array, find its index in the base array 【发布时间】:2021-12-31 18:49:47 【问题描述】:

假设a 是一个形状为(N,)b = a[k:l] 的numpy 数组。我知道x = b[i],有没有办法在不知道kl 并且不搜索a 的情况下找到jx = a[j] x

例如a = np.array([2,4,3,1,7])b = a[1:4]。我只能访问b,但想知道3a 中的索引是什么,知道它在b 中的索引是1

当然,我可以使用b.base 访问a,然后在a 中搜索项目3,但我想知道是否有一个附加到视图的方法,它返回项目的索引基本数组。

【问题讨论】:

这很难理解。您能否提供一个示例数组和您期望的输出?我们不是计算机,所以所有这些变量都很难推理:) numpy 数组中没有内置任何东西可以做到这一点。以我对形状、步幅和数据缓冲区指针的了解,我可以在大多数情况下解决这个问题,但这并不是一项简单的任务。 【参考方案1】:

正如@hpaulj 已经在 cmets 中声明的那样,没有内置功能可以这样做。但是您仍然可以根据dtype 的大小以及基址和视图之间的字节偏移量来计算基址的索引。可以从属性ndarray.__array_interface__['data'][0]获取字节偏移量

import numpy as np
import unittest

def baseIndex(array: np.ndarray, index: int) -> int:
    base = array.base
    if base is None:
        return index
    size = array.dtype.itemsize
    stride = array.strides[0] // size
    offset = (array.__array_interface__['data'][0] - base.__array_interface__['data'][0]) // size
    return offset + index * stride

a = np.array([0,1,2,3,4,5,6])
b = a
class Test(unittest.TestCase):

    def test_1_simple(self):
        """b = a"""
        b = a
        i = 1
        j = baseIndex(b, i)
        self.assertEqual(a[j], b[i])
    
    def test_2_offset(self):
        """b = a[3:]"""
        b = a[3:]
        i = 1
        j = baseIndex(b, i)
        self.assertEqual(a[j], b[i])
    
    def test_3_strided(self):
        """b = a[1::2]"""
        b = a[1::2]
        i = 1
        j = baseIndex(b, i)
        self.assertEqual(a[j], b[i])
    
    def test_4_reverse_strided(self):
        """b = a[4::-2]"""
        b = a[4::-2]
        i = 1
        j = baseIndex(b, i)
        self.assertEqual(a[j], b[i])


unittest.main(verbosity=2)

输出:

test_1_simple (__main__.Test)
b = a ... ok
test_2_offset (__main__.Test)
b = a[3:] ... ok
test_3_strided (__main__.Test)
b = a[1::2] ... ok
test_4_reverse_strided (__main__.Test)
b = a[4::-2] ... ok

----------------------------------------------------------------------
Ran 4 tests in 0.001s

OK

编辑:我现在更新了函数来处理b 不连续和/或反向的情况,感谢@Jérôme Richard 发现了这一点。此外,正如@mozway 所说,ndarray.__array_interface__ 是一个内部 numpy 细节,可能会在没有通知的情况下更改,但到目前为止我还没有看到任何其他方法。

【讨论】:

很好,但是这种方法应该谨慎使用,因为 numpy 的内部方法可能会在没有警告的情况下发生变化。 这就是我想到的那种计算方式。获取offset 可能是最“高级”的一步。 有趣的策略。但是,它并不总是有效,因为数组可以是不连续的,或者数组甚至可以有一个负步幅。你可以用np.flip(np.array([0,1,2,3,4,5,6]))试试,结果不对。 @JérômeRichard 你说得对,我忘了考虑这种情况。我现在已经更新了解决方案。

以上是关于给定一个 numpy 数组视图中项目的索引,在基本数组中找到它的索引的主要内容,如果未能解决你的问题,请参考以下文章

使用给定的一组索引访问 numpy 数组的连续行

python numpy maxpool:给定一个数组和来自argmax的索引,返回最大值

如何在给定索引列表的情况下有效地更新 numpy ndarray

Numpy 布尔型数组

两个 1D numpy / torch 数组的特殊索引以生成另一个数组

交错两个 numpy 索引数组,每个数组中的一项