如何在 Matlab 中进行高效的 k 最近邻计算

Posted

技术标签:

【中文标题】如何在 Matlab 中进行高效的 k 最近邻计算【英文标题】:How to do efficient k-nearest neighbor calculation in Matlab 【发布时间】:2014-08-08 10:31:42 【问题描述】:

我正在使用 Matlab 中的 k-最近邻算法进行数据分析。我的数据由大约 11795 x 88 数据矩阵组成,其中行是观察值,列是变量。

我的任务是为 n 个选定的测试点找到 k 最近邻。目前我正在使用以下逻辑:

对于所有的测试点

   LOOP all the data and find the k-closest neighbors (by euclidean distance)

换句话说,我循环了所有 n 个测试点。对于每个测试点,我通过欧几里得距离搜索数据(不包括测试点本身)中的 k 最近邻。对于每个测试点,这大约需要 k x 11794 次迭代。所以整个过程大约需要 n x k x 11794 次迭代。如果 n = 10000 且 k = 7,这将是大约 8.256 亿次迭代。

有没有更有效的方法来计算 k 近邻?现在大部分计算都会浪费,因为我的算法很简单:

计算到所有其他点的欧几里德距离,选取最近的点并排除最近的点,不再考虑 --> 计算到所有其他点的欧几里德距离并选取最近的点 --> 等等 -->等等。

有没有一种聪明的方法可以摆脱这种“浪费计算”?

目前这个过程在我的电脑上大约需要 7 个小时(3.2 GHz,8 GB RAM,64 位 Win 7)... :(

以下是一些明确说明的逻辑(这不是我的全部代码,但这是消耗性能的部分):

for i = 1:size(testpoints, 1) % Loop all the test points 
    neighborcandidates = all_data_excluding_testpoints; % Use the rest of the data excluding the test points in search of the k-nearest neighbors 
    testpoint = testpoints(i, :); % This is the test point for which we find k-nearest neighbors
    kneighbors = []; % Store the k-nearest neighbors here.
    for j = 1:k % Find k-nearest neighbors
        bdist = Inf; % The distance of the closest neighbor
        bind = 0; % The index of the closest neighbor
        for n = 1:size(neighborcandidates, 1) % Loop all the candidates
            if pdist([testpoint; neighborcandidates(n, :)]) < bdist % Check the euclidean distance
                bdist = pdist([testpoint; neighborcandidates(n, :)]); % Update the best distance so far
                bind = n; % Save the best found index so far
            end
        end
        kneighbors = [kneighbors; neighborcandidates(bind, :)]; % Save the found neighbour
        neighborcandidates(bind, :) = []; % Remove the neighbor from further consideration 
    end
end

【问题讨论】:

加个小例子说明清楚。 有很多循环——如果你只是在整个矩阵上运行pdist2 作为一个输入,然后将n 观察的子集作为第二个输入矩阵,会发生什么?你的电脑能处理吗/你知道这需要多长时间吗?因为这样您就可以在一行中获得您正在寻找的所有元素的成对距离,并为每个 n 找到顶部的 n 观察结果应该非常简单...... 嗨@Dan 我用pdist2来计算距离。只用了不到一分钟。休息应该没问题=)所以这是一个显着的改进=) @jjepsuomi 没问题,我已经添加了一个答案,展示了我将如何使用它 @jjepsuomi 另请参阅我对使用 Matlab 内置的答案knnsearch 【参考方案1】:

使用pdist2

A = rand(20,5);             %// This is your 11795 x 88
B = A([1, 12, 4, 8], :);    %// This is your n-by-88 subset, i.e. n=4 in this case
n = size(B,1);

D = pdist2(A,B);
[~, ind] = sort(D);
kneighbours = ind(2:2+k, :);

现在您可以使用kneighbours 来索引A 中的一行。注意kneighbours的列对应B的行

但既然您已经使用 pdist 进入统计工具箱,为什么不直接使用 Matlab 的 knnsearch

kneighbours_matlab = knnsearch(A,B,'K',k+1);

注意kneighbourskneighbours_matlab(:,2:end)' 相同

【讨论】:

+1 感谢您的帮助! =) 当我完成解决方案的实施后,我将发布预计的运行时间 =) 嗨@Dan 我决定使用pdist2 方法。运行时间现在大约是 30 秒 =) 所以运行时间大约快 x840 倍 x)【参考方案2】:

我不熟悉特定的 matlab 函数,但您可以从公式中删除 k。

有一个众所周知的选择算法

    将数组 A(大小为 n)和数字 k 作为输入。 给出数组 A 的排列,使得第 k 个最大/最小元素位于第 k 个位置。 较小的元素位于左侧,较大的元素位于右侧。

例如

A=2,4,6,8,10,1,3,5,7,9; k=5

output = 2,4,1,3,5,10,6,8,7,9

这是在 O(n) 步中完成的,不依赖于 k。

EDIT1:您还可以预先计算所有距离,因为它看起来像是您花费大部分计算的地方。这将是一个大约 800M 的矩阵,所以这不应该是现代机器上的问题。

【讨论】:

现在我看了一下这个问题。您应该首先尝试 EDIT 建议,因为它更易于实施。请记住 dist[i,j] = dist[j,i] +1 感谢您的帮助! =)我会尝试建议,让大家知道时间改进=)【参考方案3】:

我不确定它是否会加速代码,但它删除了内部的两个循环

for i = 1:size(testpoints, 1) % //Loop all the test points 
    temp = repmat(testpoints(i,:),size(neighborcandidates, 1),1);
    euclead_dist = (sum((temp - neighborcandidates).^2,2).^(0.5));
    [sort_dist ind] = sort(euclead_dist);
    lowest_k_ind = ind(1:k);
    kneighbors = neighborcandidates(lowest_k_ind, :);
    neighborcandidates(lowest_k_ind, :) = [];
end

【讨论】:

【参考方案4】:

这不行吗?

adjk = adj;

for i=1:k-1 
adj_k = adj_k*adj; 
end

kneigh = find(adj_k(n,:)>0)

给定一个节点 n 和一个索引 k?

【讨论】:

【参考方案5】:

也许这是在 Matlab 上下文中更快的代码。您还可以尝试并行函数、数据索引和近似最近邻算法,以提高理论上的效率。

% a slightly faster way to find k nearest neighbors in matlab
% find neighbors for data Y from data X

m=size(X,1);
n=size(Y,1);
IDXs_out=zeros(n,k);

distM=(repmat(X(:,1),1,n)-repmat(Y(:,1)',m,1)).^2;
for d=2:size(Y,2)
    distM=distM+(repmat(X(:,d),1,n)-repmat(Y(:,d)',m,1)).^2;
end
distM=sqrt(distM);
for i=1:k
    [~,idx]=min(distM,[],1);
    id=sub2ind(size(distM),idx',(1:n)');
    distM(id)=inf;
    IDXs_out(:,i)=idx';
end

【讨论】:

以上是关于如何在 Matlab 中进行高效的 k 最近邻计算的主要内容,如果未能解决你的问题,请参考以下文章

如何在高维数据中高效地找到k近邻?

matlab中的k最近邻分类器

情感识别基于K近邻分类算法的语音情感识别matlab 源码

如何根据一组 k 最近邻计算平均值?

K近邻(KNN)算法是基于实例的算法,如果训练样本数量庞大,预测的时候挨个计算距离效率会很低下,如何破解?

Matlab计算数组中所有(u,v)向量的最近邻距离