计算图(graph)的遍历

Posted 阳光玻璃杯

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了计算图(graph)的遍历相关的知识,希望对你有一定的参考价值。

很久没有写博客了,忙并不是借口,懒才是理由。
一直想重构CupCnn,写成一个通用的计算图,能随意搭建各种神经网络(CupCnn只能搭建有一个链路的有序的神经网络),然后把名字也改了,叫CupDnn好了。所以,今天先写着试下遍历一个计算图吧。
先随便构造一个简单的计算图,如下:

关于这个图的说明:
假设每一个Unit都执行一些计算,把计算结果推送给与他相连接的Unit,下一个Unit对输入再做计算。也就是说,如果有一个Unit,有三个输入,那么为了完成计算,它必须等待三个输入都将结果输出给它,它才可以执行它的计算。比如说,fifth需要second,third,forth三个图元的输出结果,它必须等这三个图源都完成计算,并将结果递送给它,它才可以进行计算。这里的计算就是打印图元的名字。

遍历分为深度优先和广度优先,深度优先用递归实现,广度优先用任务队列实现。
代码如下:

//define graph unit
typedef struct graph_unit
    string name;
    bool visit;
    int inputSize;
    int outputSize;
    vector<struct graph_unit *> *outputs;
    vector<struct graph_unit *> *inputs;
Unit,*PUnit;

PUnit createGraph()

    PUnit input = new Unit;
    input->name = "intput";
    input->inputSize = 0;
    input->outputSize = 3;
    input->visit = false;
    input->inputs = nullptr;
    input->outputs = new vector<PUnit>;
    
    PUnit first = new Unit;
    first->name = "first";
    first->inputSize = 1;
    first->outputSize = 1;
    first->visit = false;
    first->inputs = new vector<PUnit>;
    first->outputs = new vector<PUnit>;
    
    PUnit second = new Unit;
    second->name = "second";
    second->inputSize = 1;
    second->outputSize = 1;
    second->visit = false;
    second->inputs = new vector<PUnit>;
    second->outputs = new vector<PUnit>;
    
    PUnit third = new Unit;
    third->name = "third";
    third->inputSize = 1;
    third->outputSize = 1;
    third->visit = false;
    third->inputs = new vector<PUnit>;
    third->outputs = new vector<PUnit>;
    
    PUnit forth = new Unit;
    forth->name = "forth";
    forth->inputSize = 1;
    forth->outputSize = 1;
    forth->visit = false;
    forth->inputs = new vector<PUnit>;
    forth->outputs = new vector<PUnit>;
    
    PUnit fifth = new Unit;
    fifth->name = "fifth";
    fifth->inputSize = 3;
    fifth->outputSize = 0;
    fifth->visit = false;
    fifth->inputs = new vector<PUnit>;
    fifth->outputs = nullptr;
    
    input->outputs->push_back(first);
    input->outputs->push_back(second);
    input->outputs->push_back(third);
    
    first->outputs->push_back(forth);
    
    second->outputs->push_back(fifth);
    third->outputs->push_back(fifth);
    forth->outputs->push_back(fifth);
    
    return input;


//广度优先遍历图,释放所有的资源
void broadScanDestroyGraph(PUnit input)

    queue<PUnit> preProcessQueue;
    preProcessQueue.push(input);
    while (!preProcessQueue.empty()) 
        PUnit tmp = preProcessQueue.front();
        preProcessQueue.pop();
        if(!tmp)continue;
        if(tmp->outputs)
            for(int i=0;i<tmp->outputs->size();i++)
                PUnit p = tmp->outputs->at(i);
                p->inputSize--;
                if(!p->inputSize)
                    preProcessQueue.push(p);
                
            
        
        cout<<"destory: "<<tmp->name<<endl;
        if(tmp->inputs)
            tmp->inputs->clear();
            delete tmp->inputs;
        
        if(tmp->outputs)
            tmp->outputs->clear();
            delete tmp->outputs;
        
        if(tmp)delete tmp;
        tmp = nullptr;
    

//广度优先遍历图,执行计算
//采用任务队列
void broadScanGraph(PUnit input)

    queue<PUnit> preProcessQueue;
    preProcessQueue.push(input);
    while (!preProcessQueue.empty()) 
        PUnit tmp = preProcessQueue.front();
        preProcessQueue.pop();
        if(!tmp)continue;
        tmp->visit = true;
        cout<<tmp->name<<endl;
        if(!tmp->outputSize)continue;
        for(int i=0;i<tmp->outputs->size();i++)
            PUnit p = tmp->outputs->at(i);
            p->inputs->push_back(tmp);
            if(p->inputSize == p->inputs->size() && !p->visit)
                preProcessQueue.push(p);
            
        
    

//深度优先遍历图,释放资源
void deepScanDestoryGraph(PUnit input)

    if(!input->outputSize)return;
    for(int i=0;i<input->outputs->size();i++)
        PUnit p = input->outputs->at(i);
        p->inputSize--;
        if(!p->inputSize)
            deepScanDestoryGraph(p);
            cout<<"destory: "<<p->name<<endl;
            if(p->inputs)
                p->inputs->clear();
                delete p->inputs;
            
            if(p->outputs)
                p->outputs->clear();
                delete p->outputs;
            
            if(p)delete p;
            p = nullptr;
        
    

//深度优先遍历图,执行计算
//采用递归
void deepScanGraph(PUnit input)

    if(!input->outputSize)return;
    for(int i=0;i<input->outputs->size();i++)
        PUnit p = input->outputs->at(i);
        p->inputs->push_back(input);
        if(p->inputSize == p->inputs->size() && !p->visit)
            cout<<p->name<<endl;
            p->visit = true;
            deepScanGraph(p);
        
    


int main(int argc, const char * argv[]) 
    // insert code here...
    cout<<"-------broad scan grapp--------"<<endl;
    PUnit input = createGraph();
    broadScanGraph(input);
    broadScanDestroyGraph(input);
    
    cout<<"-------deep scan grapp--------"<<endl;
    input = createGraph();
    deepScanGraph(input);
    deepScanDestoryGraph(input);
    return 0;


结果如下:

-------broad scan grapp--------
intput
first
second
third
forth
fifth
destory: intput
destory: first
destory: second
destory: third
destory: forth
destory: fifth
-------deep scan grapp--------
first
forth
second
third
fifth
destory: forth
destory: first
destory: second
destory: fifth
destory: third
Program ended with exit code: 0

再增加一个图元,构造更加复杂一点看看对不对:

构造代码如下:

PUnit createGraph2()

    PUnit input = new Unit;
    input->name = "intput";
    input->inputSize = 0;
    input->outputSize = 3;
    input->visit = false;
    input->inputs = nullptr;
    input->outputs = new vector<PUnit>;
    
    PUnit first = new Unit;
    first->name = "first";
    first->inputSize = 1;
    first->outputSize = 1;
    first->visit = false;
    first->inputs = new vector<PUnit>;
    first->outputs = new vector<PUnit>;
    
    PUnit second = new Unit;
    second->name = "second";
    second->inputSize = 1;
    second->outputSize = 1;
    second->visit = false;
    second->inputs = new vector<PUnit>;
    second->outputs = new vector<PUnit>;
    
    PUnit third = new Unit;
    third->name = "third";
    third->inputSize = 1;
    third->outputSize = 1;
    third->visit = false;
    third->inputs = new vector<PUnit>;
    third->outputs = new vector<PUnit>;
    
    PUnit forth = new Unit;
    forth->name = "forth";
    forth->inputSize = 1;
    forth->outputSize = 1;
    forth->visit = false;
    forth->inputs = new vector<PUnit>;
    forth->outputs = new vector<PUnit>;
    
    PUnit fifth = new Unit;
    fifth->name = "fifth";
    fifth->inputSize = 3;
    fifth->outputSize = 1;
    fifth->visit = false;
    fifth->inputs = new vector<PUnit>;
    fifth->outputs = new vector<PUnit>;

    PUnit sixth = new Unit;
    sixth->name = "sixth";
    sixth->inputSize = 2;
    sixth->outputSize = 0;
    sixth->visit = false;
    sixth->inputs = new vector<PUnit>;
    sixth->outputs = nullptr;

    
    input->outputs->push_back(first);
    input->outputs->push_back(second);
    input->outputs->push_back(third);
    
    first->outputs->push_back(forth);
    
    second->outputs->push_back(fifth);
    third->outputs->push_back(fifth);
    forth->outputs->push_back(fifth);
    
    fifth->outputs->push_back(sixth);
    forth->outputs->push_back(sixth);
    
    return input;

执行结果如下:

-------broad scan grapp--------
intput
first
second
third
forth
fifth
sixth
destory: intput
destory: first
destory: second
destory: third
destory: forth
destory: fifth
destory: sixth
-------deep scan grapp--------
first
forth
second
third
fifth
sixth
destory: forth
destory: first
destory: second
destory: sixth
destory: fifth
destory: third
Program ended with exit code: 0

可见还是按预期运行的。

以上是关于计算图(graph)的遍历的主要内容,如果未能解决你的问题,请参考以下文章

Graph 感悟

785. 判断二分图——本质上就是图的遍历 dfs或者bfs

Detect Cycle In Directed/Undirected Graph

ROS学习笔记-rqt_graph生成ROS系统中计算图

算法笔记:树堆和图

tensorflow-新计算图