使用 Numpy 对除第一个维度之外的所有维度进行平面索引

Posted

技术标签:

【中文标题】使用 Numpy 对除第一个维度之外的所有维度进行平面索引【英文标题】:Flat indexing of all but first dimension with Numpy 【发布时间】:2022-01-20 07:11:30 【问题描述】:

有没有办法通过 NumPy 对剩余维度使用平面索引?我正在尝试将以下 MATLAB 函数转换为 Python

function [indices, weights] = locate(values, gridpoints)
    indices = ones(size(values));
    weights = zeros([2, size(values)]);

    for ix = 1:numel(values)
        if values(ix) <= gridpoints(1)
            indices(ix) = 1;
            weights(:, ix) = [1; 0];
        elseif values(ix) >= gridpoints(end)
            indices(ix) = length(gridpoints) - 1;
            weights(:, ix) = [0; 1];
        else
            indices(ix) = find(gridpoints <= values(ix), 1, 'last');    
            weights(:, ix) = ...
                [gridpoints(indices(ix) + 1) - values(ix); ...
                 values(ix) - gridpoints(indices(ix))] ...
                / (gridpoints(indices(ix) + 1) - gridpoints(indices(ix)));
        end
    end
end

但我无法理解与 MATLAB 的 weights(:, ix) 等效的 NumPy 是什么——也就是说,仅在剩余维度中进行线性索引。

我希望可以直接翻译语法,但假设 values 是一个 3×4 数组,那么 weights 变成一个 2×3×4 数组。在 MATLAB 中,weights(:, ix) 是一个 2×1 数组,而在 Python 中 weights[:, ix] 是一个 2×3 数组。

我认为我已经处理了下面函数中的所有其他内容。

import numpy as np


def locate(values, gridpoints):
    indices = np.zeros(np.shape(values), dtype=int)
    weights = np.zeros((2,) + np.shape(values))

    for ix in range(values.size):
        if values.flat[ix] <= gridpoints[0]:
            indices.flat[ix] = 0
            # weights[:, ix] = [1, 0]
        elif values.flat[ix] >= gridpoints[-1]:
            indices.flat[ix] = gridpoints.size - 2
            # weights[:, ix] = [0, 1]
        else:
            indices.flat[ix] = (
                np.argwhere(gridpoints <= values.flat[ix]).flatten()[-1]
            )
            # weights[:, ix] = (
            #         np.array([gridpoints[indices.flat[ix] + 1] - values.flat[ix],
            #                   values.flat[ix] - gridpoints[indices.flat[ix]]])
            #         / (gridpoints[indices.flat[ix] + 1] - gridpoints[indices.flat[ix]])
            # )

    return indices, weights

你有什么建议吗?也许我只是想问题全错了。我也尝试尽可能简单地编写代码,因为我打算稍后使用 Numba 来加速它。

【问题讨论】:

ix 只是一个标量,因此 MATLAB 中 x(:,ix) 的等价物在 Python 中将是 x[:,ix] (注意由于 1 与 0-基于索引)。这里没有剩余的维度或其他东西。在 MATLAB 中,这会从二维矩阵中获取列 ix,这在 NumPy 中也是微不足道的。 @Adriaan 不完全是。假设在 MATLAB (Python) 中 values = zeros(3, 4) (values = np.zeros((3, 4))。那么 weights 是一个 2×3×4 数组,weights(:, 1) 是一个 2×1 数组 (weights[:, 1] 一个 2× 3个数组)。 能否请您edit 提出minimal reproducible example 的问题? IE。除了这个索引问题,把所有东西都拿出来,并添加一些示例数据(例如rand(3,4)),显示 MATLAB 做了什么以及你目前在 Python 中获得了什么?所有循环和if 语句似乎与这个索引问题无关,因此更难掌握。 没有直接的numpy 等效项。是 MATLAB 在这里玩有趣的游戏。 @hpaulj 这些“有趣的游戏”只是线性索引,非常有用,我在使用 NumPy 时经常会错过。 【参考方案1】:

根据hpaulj 的comment,似乎没有直接的 NumPy 等效项。在缺乏的情况下,我能想到的最好的方法是按照下面的代码和NumPy for Matlab Users 的建议重塑weights 数组。

import numpy as np


def locate(values, gridpoints):
    indices = np.zeros(values.shape, dtype=int)
    weights = np.zeros((2, values.size))  # Temporarily make weights 2-by-N

    for ix in range(values.size):
        if values.flat[ix] <= gridpoints[0]:
            indices.flat[ix] = 0
            weights[:, ix] = [1, 0]
        elif values.flat[ix] >= gridpoints[-1]:
            indices.flat[ix] = gridpoints.size - 2
            weights[:, ix] = [0, 1]
        else:
            indices.flat[ix] = (
                np.argwhere(gridpoints <= values.flat[ix]).flatten()[-1]
            )
            weights[:, ix] = (
                    np.array([gridpoints[indices.flat[ix] + 1] - values.flat[ix],
                              values.flat[ix] - gridpoints[indices.flat[ix]]])
                    / (gridpoints[indices.flat[ix] + 1] - gridpoints[indices.flat[ix]])
            )
    
    # Give weights correct dimensions
    weights.shape = (2,) + values.shape
    
    return indices, weights

【讨论】:

这确实是 AFAIK 的唯一解决方案。 Python 不做线性索引,这是你在 MATLAB 中在 2 维和 3 维中所做的。

以上是关于使用 Numpy 对除第一个维度之外的所有维度进行平面索引的主要内容,如果未能解决你的问题,请参考以下文章

[Tips] numpy 收缩维度

在 Python NumPy 中,维度和轴是啥?

如何使用 numpy 或 pandas 创建(或更改)数组/列表的维度?

numpy 插入一个元素,保持数组维度的数量

在进行矩阵工作时如何理解循环和额外 numpy 维度之间的权衡?

numpy数组的堆叠:numpy.stack, numpy.hstack, numpy.vstack