Kd 树迭代实现(C++)

Posted

技术标签:

【中文标题】Kd 树迭代实现(C++)【英文标题】:Kd Tree Iterative implementation ( C++ ) 【发布时间】:2011-06-10 13:57:32 【问题描述】:

你好,有没有人用 C++ 迭代实现 Kd-Tree。 我试过了,但是当节点数是奇数时它失败了。 到目前为止,这是我的代码。详情请参考http://ldots.org/kdtree/#buildingAkDTree网站。

#include <stdio.h>
#include <iostream>
#include <fstream>
#include <vector>
#include <algorithm>
#include <stack>
#include <queue>
#include <iomanip>


struct Point 
    double pt[2];
    int id;
;

typedef std::vector<Point> TPointVector;

struct KdNode 
    double point[2];
    int id;
    double desc;
    bool leaf;

    KdNode *left;
    KdNode *right;
    KdNode *parent;
    KdNode(KdNode *parent_):parent(parent_),leaf(false)
    KdNode(KdNode *parent_,TPointVector::iterator itr, int depth, TPointVector &pv);
    KdNode(KdNode *t, TPointVector &pv);

;

KdNode::KdNode(KdNode *parent_,TPointVector::iterator itr, int depth, TPointVector &pv) 
    parent = parent_ ;
    left   = 0;
    right  = 0;
    desc   = itr->pt[depth % 2 ];
    leaf   = false;


KdNode::KdNode(KdNode *t, TPointVector &pv) 
    id       = pv[0].id;
    point[0] = pv[0].pt[0];
    point[1] = pv[0].pt[1];
    left     = 0;
    right    = 0;
    parent   = t;
    leaf     = true;


KdNode *pRoot = 0;


struct ComparePoints 
    int cord;
    ComparePoints(int  cord_) : cord(cord_ % 2)  ;
    bool operator()(const Point& lhs, const Point& rhs) const 
        return lhs.pt[cord] < rhs.pt[cord];
    
;
void buildLeftTree(std::stack<TPointVector > &stackL) 
    KdNode *pCurrent = pRoot;
    KdNode **pNode  = &(pCurrent->left);
    int depth      = 0; 
    bool changeDirection = false;
    while (! stackL.empty()) 
        TPointVector pv = stackL.top(); 
        stackL.pop();
        if ( pv.size() != 1 )  
            std::sort(pv.begin(), pv.end(), ComparePoints(++depth));

            *pNode = new KdNode(pCurrent, pv.begin() + pv.size()/2, depth, pv);

            TPointVector lvp,rvp;
            std::size_t median = pv.size() / 2;
            std::copy(pv.begin(), pv.begin() + median, std::back_inserter(lvp));
            std::copy(pv.begin() + median, pv.end(), std::back_inserter(rvp));

            stackL.push(rvp); 
            stackL.push(lvp);

            if ( changeDirection ) 
                pCurrent = pCurrent->right;
                changeDirection = false;
             else 
                pCurrent = pCurrent->left;
                       
            pNode = &(pCurrent->left);

         else 
            KdNode **pNodeLeft   = &(pCurrent->left);
            *pNodeLeft  = new KdNode(pCurrent, pv);
            pv = stackL.top();
            stackL.pop();

            KdNode **pNodeRight   = &(pCurrent->right);
            *pNodeRight  = new KdNode(pCurrent,pv);

            pCurrent = pCurrent->parent;
            pNode  = &(pCurrent->right);
            changeDirection = true;
            depth--;
                   
    


void buildRightTree(std::stack<TPointVector > &stackR) 
    KdNode *pCurrent = pRoot;
    KdNode **pNode  = &(pCurrent->right);
    int depth      = 0; 
    bool changeDirection = true;
    while (! stackR.empty()) 
        TPointVector pv = stackR.top(); 
        stackR.pop();

        if ( pv.size() != 1 )  
            std::sort(pv.begin(), pv.end(), ComparePoints(++depth));
            *pNode = new KdNode(pCurrent, pv.begin() + pv.size()/2, depth, pv);

            TPointVector lvp,rvp;
            std::size_t median = pv.size() / 2;
            std::copy(pv.begin(), pv.begin() + median, std::back_inserter(lvp));
            std::copy(pv.begin() + median, pv.end(), std::back_inserter(rvp));

            stackR.push(rvp); 
            stackR.push(lvp);       

            if ( changeDirection ) 
                pCurrent = pCurrent->right;
                changeDirection = false;
             else 
                pCurrent = pCurrent->left;
                   
            pNode = &(pCurrent->left);

         else 
            KdNode **pNodeLeft   = &(pCurrent->left);
            *pNodeLeft  = new KdNode(pCurrent, pv);
            pv = stackR.top();
            stackR.pop();

            KdNode **pNodeRight   = &(pCurrent->right);
            *pNodeRight  = new KdNode(pCurrent,pv);

            pCurrent = pCurrent->parent;
            pNode  = &(pCurrent->right);
            depth--;
            changeDirection = true;
                   
    



void constructKD(TPointVector &pv) 
    int depth = 0;
    std::sort(pv.begin(), pv.end(), ComparePoints(depth));

    pRoot        = new KdNode(0);
    pRoot->desc  = ( pv.begin() + pv.size()/2)->pt[0];
    pRoot->left  = 0;
    pRoot->right = 0;

    TPointVector lvp, rvp;
    std::copy(pv.begin(), pv.begin() + pv.size()/2, std::back_inserter(lvp));
    std::copy(pv.begin() + pv.size()/2, pv.end(), std::back_inserter(rvp));

    std::stack<TPointVector > stackL, stackR;
    stackL.push(lvp);
    stackR.push(rvp);

    buildLeftTree(stackL);
    buildRightTree(stackR);


void readPoints(const char* fileName, TPointVector& points) 
    std::ifstream input(fileName);

    if ( input.peek() != EOF ) 
        while(!input.eof()) 
            int id = 0;
            double x_cord, y_cord;
            input >> id >> x_cord >> y_cord;

            Point t ;
            t.pt[0] = x_cord;
            t.pt[1] = y_cord;
            t.id    = id;

            points.push_back(t);
        
        input.close();
       

void _printLevelWise(KdNode *node, std::queue<KdNode *> Q) 
    int depth = 0;
    while ( ! Q.empty()) 
        KdNode *qNode = Q.front();Q.pop();
        if ( qNode->leaf ) 
            std::cout << "[" << qNode->id << "]" << std::setprecision (25) << "(" << qNode->point[0] << "," << qNode->point[1] << ")" << std::endl;
         else 
            std::cout << std::setprecision (25) << qNode->desc << std::endl;
               
        if (qNode->left != 0)
            Q.push(qNode->left);
        if (qNode->right != 0)
            Q.push(qNode->right);
    

void PrintLevelWise(KdNode *node) 
    std::queue<KdNode *> Q;
    Q.push(node);
    _printLevelWise(node, Q);

int main ( int argc, char **argv ) 
    if ( argc <= 1 ) 
        return 0;
    
    TPointVector points;
    readPoints(argv[1], points);
    for ( TPointVector::iterator itr = points.begin(); itr != points.end(); ++itr) 
        std::cout << "(" << itr->pt[0] << "," << itr->pt[1] << ")" << std::endl;
    
    if ( points.size() == 0 )
        return 0;
    constructKD(points);
    PrintLevelWise(pRoot);
    std::cout << "Construction of KD Tree Done " << std::endl;

失败的示例输入:

1 6 1 
2 5 5 
3 9 6 
4 3 6 
5 4 9

适用的示例输入:

1 6 1 
2 5 5 
3 9 6 
4 3 6 
5 4 9 
6 4 0 
7 7 9 
8 2 9

【问题讨论】:

奇数时如何“失败”?段错误? 更新问题,抱歉没有放详细信息。 对于相关方来说,“失败”是省略了读取的节点之一,即只有 N-1 个节点最终出现在 N 个条目的树中(当 N 为奇数时)。 【参考方案1】:

buildLeftTreebuildRightTree 中的 else 无法处理右子树上的节点数为奇数的情况。在您的 5 点示例中,buildRightTree 中的 else 案例在 stackR 上以三个点结束,第一个用于left 节点,第二个它静默分配给right 节点,就好像它是唯一的节点。

这是由于您的中位数选择使用的标准与您引用的网站上列出的标准不同。

std::size_t median = pv.size() / 2; // degenerates in cases where size() is odd

您的选择标准应该基于 中位数 x 或 y 值,并根据该标准使用子列表(不假定任何给定大小)。

【讨论】:

我将进行此更改并重新测试,但我还有另一个问题,为什么没有可用的迭代 KD-Tree 实现。是否有可用的 KD-tree 实现的简单版本(递归可能没问题)【参考方案2】:

buildLeftTreebuildRightTreeelse 中。

添加

if (pv.size() > 1)

    pNode = &(pCurrent->right);
    continue;

在堆栈的toppop 函数之间。会好的。

我还有一个问题,点数百万的时候,建树会很慢,如何提高性能?

【讨论】:

以上是关于Kd 树迭代实现(C++)的主要内容,如果未能解决你的问题,请参考以下文章

K 近邻算法(KNN)与KD 树实现

kd树原理及实现

python kd树 搜索

机器学习100天(三十三):033 KD树的Python实现

KD树

二叉搜索树的详细实现(C++)