计算图(graph)的遍历
Posted 阳光玻璃杯
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了计算图(graph)的遍历相关的知识,希望对你有一定的参考价值。
很久没有写博客了,忙并不是借口,懒才是理由。
一直想重构CupCnn,写成一个通用的计算图,能随意搭建各种神经网络(CupCnn只能搭建有一个链路的有序的神经网络),然后把名字也改了,叫CupDnn好了。所以,今天先写着试下遍历一个计算图吧。
先随便构造一个简单的计算图,如下:
关于这个图的说明:
假设每一个Unit都执行一些计算,把计算结果推送给与他相连接的Unit,下一个Unit对输入再做计算。也就是说,如果有一个Unit,有三个输入,那么为了完成计算,它必须等待三个输入都将结果输出给它,它才可以执行它的计算。比如说,fifth需要second,third,forth三个图元的输出结果,它必须等这三个图源都完成计算,并将结果递送给它,它才可以进行计算。这里的计算就是打印图元的名字。
遍历分为深度优先和广度优先,深度优先用递归实现,广度优先用任务队列实现。
代码如下:
//define graph unit
typedef struct graph_unit
string name;
bool visit;
int inputSize;
int outputSize;
vector<struct graph_unit *> *outputs;
vector<struct graph_unit *> *inputs;
Unit,*PUnit;
PUnit createGraph()
PUnit input = new Unit;
input->name = "intput";
input->inputSize = 0;
input->outputSize = 3;
input->visit = false;
input->inputs = nullptr;
input->outputs = new vector<PUnit>;
PUnit first = new Unit;
first->name = "first";
first->inputSize = 1;
first->outputSize = 1;
first->visit = false;
first->inputs = new vector<PUnit>;
first->outputs = new vector<PUnit>;
PUnit second = new Unit;
second->name = "second";
second->inputSize = 1;
second->outputSize = 1;
second->visit = false;
second->inputs = new vector<PUnit>;
second->outputs = new vector<PUnit>;
PUnit third = new Unit;
third->name = "third";
third->inputSize = 1;
third->outputSize = 1;
third->visit = false;
third->inputs = new vector<PUnit>;
third->outputs = new vector<PUnit>;
PUnit forth = new Unit;
forth->name = "forth";
forth->inputSize = 1;
forth->outputSize = 1;
forth->visit = false;
forth->inputs = new vector<PUnit>;
forth->outputs = new vector<PUnit>;
PUnit fifth = new Unit;
fifth->name = "fifth";
fifth->inputSize = 3;
fifth->outputSize = 0;
fifth->visit = false;
fifth->inputs = new vector<PUnit>;
fifth->outputs = nullptr;
input->outputs->push_back(first);
input->outputs->push_back(second);
input->outputs->push_back(third);
first->outputs->push_back(forth);
second->outputs->push_back(fifth);
third->outputs->push_back(fifth);
forth->outputs->push_back(fifth);
return input;
//广度优先遍历图,释放所有的资源
void broadScanDestroyGraph(PUnit input)
queue<PUnit> preProcessQueue;
preProcessQueue.push(input);
while (!preProcessQueue.empty())
PUnit tmp = preProcessQueue.front();
preProcessQueue.pop();
if(!tmp)continue;
if(tmp->outputs)
for(int i=0;i<tmp->outputs->size();i++)
PUnit p = tmp->outputs->at(i);
p->inputSize--;
if(!p->inputSize)
preProcessQueue.push(p);
cout<<"destory: "<<tmp->name<<endl;
if(tmp->inputs)
tmp->inputs->clear();
delete tmp->inputs;
if(tmp->outputs)
tmp->outputs->clear();
delete tmp->outputs;
if(tmp)delete tmp;
tmp = nullptr;
//广度优先遍历图,执行计算
//采用任务队列
void broadScanGraph(PUnit input)
queue<PUnit> preProcessQueue;
preProcessQueue.push(input);
while (!preProcessQueue.empty())
PUnit tmp = preProcessQueue.front();
preProcessQueue.pop();
if(!tmp)continue;
tmp->visit = true;
cout<<tmp->name<<endl;
if(!tmp->outputSize)continue;
for(int i=0;i<tmp->outputs->size();i++)
PUnit p = tmp->outputs->at(i);
p->inputs->push_back(tmp);
if(p->inputSize == p->inputs->size() && !p->visit)
preProcessQueue.push(p);
//深度优先遍历图,释放资源
void deepScanDestoryGraph(PUnit input)
if(!input->outputSize)return;
for(int i=0;i<input->outputs->size();i++)
PUnit p = input->outputs->at(i);
p->inputSize--;
if(!p->inputSize)
deepScanDestoryGraph(p);
cout<<"destory: "<<p->name<<endl;
if(p->inputs)
p->inputs->clear();
delete p->inputs;
if(p->outputs)
p->outputs->clear();
delete p->outputs;
if(p)delete p;
p = nullptr;
//深度优先遍历图,执行计算
//采用递归
void deepScanGraph(PUnit input)
if(!input->outputSize)return;
for(int i=0;i<input->outputs->size();i++)
PUnit p = input->outputs->at(i);
p->inputs->push_back(input);
if(p->inputSize == p->inputs->size() && !p->visit)
cout<<p->name<<endl;
p->visit = true;
deepScanGraph(p);
int main(int argc, const char * argv[])
// insert code here...
cout<<"-------broad scan grapp--------"<<endl;
PUnit input = createGraph();
broadScanGraph(input);
broadScanDestroyGraph(input);
cout<<"-------deep scan grapp--------"<<endl;
input = createGraph();
deepScanGraph(input);
deepScanDestoryGraph(input);
return 0;
结果如下:
-------broad scan grapp--------
intput
first
second
third
forth
fifth
destory: intput
destory: first
destory: second
destory: third
destory: forth
destory: fifth
-------deep scan grapp--------
first
forth
second
third
fifth
destory: forth
destory: first
destory: second
destory: fifth
destory: third
Program ended with exit code: 0
再增加一个图元,构造更加复杂一点看看对不对:
构造代码如下:
PUnit createGraph2()
PUnit input = new Unit;
input->name = "intput";
input->inputSize = 0;
input->outputSize = 3;
input->visit = false;
input->inputs = nullptr;
input->outputs = new vector<PUnit>;
PUnit first = new Unit;
first->name = "first";
first->inputSize = 1;
first->outputSize = 1;
first->visit = false;
first->inputs = new vector<PUnit>;
first->outputs = new vector<PUnit>;
PUnit second = new Unit;
second->name = "second";
second->inputSize = 1;
second->outputSize = 1;
second->visit = false;
second->inputs = new vector<PUnit>;
second->outputs = new vector<PUnit>;
PUnit third = new Unit;
third->name = "third";
third->inputSize = 1;
third->outputSize = 1;
third->visit = false;
third->inputs = new vector<PUnit>;
third->outputs = new vector<PUnit>;
PUnit forth = new Unit;
forth->name = "forth";
forth->inputSize = 1;
forth->outputSize = 1;
forth->visit = false;
forth->inputs = new vector<PUnit>;
forth->outputs = new vector<PUnit>;
PUnit fifth = new Unit;
fifth->name = "fifth";
fifth->inputSize = 3;
fifth->outputSize = 1;
fifth->visit = false;
fifth->inputs = new vector<PUnit>;
fifth->outputs = new vector<PUnit>;
PUnit sixth = new Unit;
sixth->name = "sixth";
sixth->inputSize = 2;
sixth->outputSize = 0;
sixth->visit = false;
sixth->inputs = new vector<PUnit>;
sixth->outputs = nullptr;
input->outputs->push_back(first);
input->outputs->push_back(second);
input->outputs->push_back(third);
first->outputs->push_back(forth);
second->outputs->push_back(fifth);
third->outputs->push_back(fifth);
forth->outputs->push_back(fifth);
fifth->outputs->push_back(sixth);
forth->outputs->push_back(sixth);
return input;
执行结果如下:
-------broad scan grapp--------
intput
first
second
third
forth
fifth
sixth
destory: intput
destory: first
destory: second
destory: third
destory: forth
destory: fifth
destory: sixth
-------deep scan grapp--------
first
forth
second
third
fifth
sixth
destory: forth
destory: first
destory: second
destory: sixth
destory: fifth
destory: third
Program ended with exit code: 0
可见还是按预期运行的。
以上是关于计算图(graph)的遍历的主要内容,如果未能解决你的问题,请参考以下文章
785. 判断二分图——本质上就是图的遍历 dfs或者bfs