PGL图学习之图神经网络GraphSAGEGIN图采样算法[系列七]
Posted 汀、
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了PGL图学习之图神经网络GraphSAGEGIN图采样算法[系列七]相关的知识,希望对你有一定的参考价值。
0. PGL图学习之图神经网络GraphSAGE、GIN图采样算法[系列七]
本项目链接:https://aistudio.baidu.com/aistudio/projectdetail/5061984?contributionType=1
相关项目参考:更多资料见主页
关于图计算&图学习的基础知识概览:前置知识点学习(PGL)[系列一] https://aistudio.baidu.com/aistudio/projectdetail/4982973?contributionType=1
图机器学习(GML)&图神经网络(GNN)原理和代码实现(前置学习系列二):https://aistudio.baidu.com/aistudio/projectdetail/4990947?contributionType=1
在图神经网络中,使用的数据集可能是亿量级的数据,而由于GPU/CPU资源有限无法一次性全图送入计算资源,需要借鉴深度学习中的mini-batch思想。
传统的深度学习mini-batch训练每个batch的样本之间无依赖,多层样本计算量固定;而在图神经网络中,每个batch中的节点之间互相依赖,在计算多层时会导致计算量爆炸,因此引入了图采样的概念。
GraphSAGE也是图嵌入算法中的一种。在论文Inductive Representation Learning on Large Graphs 在大图上的归纳表示学习中提出。github链接和官方介绍链接。
与node2vec相比较而言,node2vec是在图的节点级别上进行嵌入,GraphSAGE则是在整个图的级别上进行嵌入。之前的网络表示学习的transductive,难以从而提出了一个inductive的GraphSAGE算法。GraphSAGE同时利用节点特征信息和结构信息得到Graph Embedding的映射,相比之前的方法,之前都是保存了映射后的结果,而GraphSAGE保存了生成embedding的映射,可扩展性更强,对于节点分类和链接预测问题的表现也比较突出。
0.1提出背景
现存的方法需要图中所有的顶点在训练embedding的时候都出现;这些前人的方法本质上是transductive,不能自然地泛化到未见过的顶点。文中提出了GraphSAGE,是一个inductive的框架,可以利用顶点特征信息(比如文本属性)来高效地为没有见过的顶点生成embedding。GraphSAGE是为了学习一种节点表示方法,即如何通过从一个顶点的局部邻居采样并聚合顶点特征,而不是为每个顶点训练单独的embedding。
这个算法在三个inductive顶点分类benchmark上超越了那些很强的baseline。文中基于citation和Reddit帖子数据的信息图中对未见过的顶点分类,实验表明使用一个PPI(protein-protein interactions)多图数据集,算法可以泛化到完全未见过的图上。
0.2 回顾GCN及其问题
在大型图中,节点的低维向量embedding被证明了作为各种各样的预测和图分析任务的特征输入是非常有用的。顶点embedding最基本的基本思想是使用降维技术从高维信息中提炼一个顶点的邻居信息,存到低维向量中。这些顶点嵌入之后会作为后续的机器学习系统的输入,解决像顶点分类、聚类、链接预测这样的问题。
- GCN虽然能提取图中顶点的embedding,但是存在一些问题:
- GCN的基本思想: 把一个节点在图中的高纬度邻接信息降维到一个低维的向量表示。
- GCN的优点: 可以捕捉graph的全局信息,从而很好地表示node的特征。
- GCN的缺点: Transductive learning的方式,需要把所有节点都参与训练才能得到node embedding,无法快速得到新node的embedding。
1.图采样算法
1.1 GraphSage: Representation Learning on Large Graphs
图采样算法:顾名思义,图采样算法就是在一张图中进行采样得到一个子图,这里的采样并不是随机采样,而是采取一些策略。典型的图采样算法包括GraphSAGE、PinSAGE等。
文章码源链接:
https://cs.stanford.edu/people/jure/pubs/graphsage-nips17.pdf
https://github.com/williamleif/GraphSAGE
前面 GCN 讲解的文章中,我使用的图节点个数非常少,然而在实际问题中,一张图可能节点非常多,因此就没有办法一次性把整张图送入计算资源,所以我们应该使用一种有效的采样算法,从全图中采样出一个子图 ,这样就可以进行训练了。
GraphSAGE与GCN对比:
既然新增的节点,一定会改变原有节点的表示,那么为什么一定要得到每个节点的一个固定的表示呢?何不直接学习一种节点的表示方法。去学习一个节点的信息是怎么通过其邻居节点的特征聚合而来的。 学习到了这样的“聚合函数”,而我们本身就已知各个节点的特征和邻居关系,我们就可以很方便地得到一个新节点的表示了。
GCN等transductive的方法,学到的是每个节点的一个唯一确定的embedding; 而GraphSAGE方法学到的node embedding,是根据node的邻居关系的变化而变化的,也就是说,即使是旧的node,如果建立了一些新的link,那么其对应的embedding也会变化,而且也很方便地学到。
在了解图采样算法前,我们至少应该保证采样后的子图是连通的。例如上图图中,左边采样的子图就是连通的,右边的子图不是连通的。
GraphSAGE的核心:GraphSAGE不是试图学习一个图上所有node的embedding,而是学习一个为每个node产生embedding的映射。 GraphSage框架中包含两个很重要的操作:Sample采样和Aggregate聚合。这也是其名字GraphSage(Graph SAmple and aggreGatE)的由来。GraphSAGE 主要分两步:采样、聚合。GraphSAGE的采样方式是邻居采样,邻居采样的意思是在某个节点的邻居节点中选择几个节点作为原节点的一阶邻居,之后对在新采样的节点的邻居中继续选择节点作为原节点的二阶节点,以此类推。
文中不是对每个顶点都训练一个单独的embeddding向量,而是训练了一组aggregator functions,这些函数学习如何从一个顶点的局部邻居聚合特征信息(见图1)。每个聚合函数从一个顶点的不同的hops或者说不同的搜索深度聚合信息。测试或是推断的时候,使用训练好的系统,通过学习到的聚合函数来对完全未见过的顶点生成embedding。
GraphSAGE 是Graph SAmple and aggreGatE的缩写,其运行流程如上图所示,可以分为三个步骤:
- 对图中每个顶点邻居顶点进行采样,因为每个节点的度是不一致的,为了计算高效, 为每个节点采样固定数量的邻居
- 根据聚合函数聚合邻居顶点蕴含的信息
- 得到图中各顶点的向量表示供下游任务使用
邻居采样的优点:
- 极大减少计算量
- 允许泛化到新连接关系,个人理解类似dropout的思想,能增强模型的泛化能力
采样的阶段首先选取一个点,然后随机选取这个点的一阶邻居,再以这些邻居为起点随机选择它们的一阶邻居。例如下图中,我们要预测 0 号节点,因此首先随机选择 0 号节点的一阶邻居 2、4、5,然后随机选择 2 号节点的一阶邻居 8、9;4 号节点的一阶邻居 11、12;5 号节点的一阶邻居 13、15
聚合具体来说就是直接将子图从全图中抽离出来,从最边缘的节点开始,一层一层向里更新节点
上图展示了邻居采样的优点,极大减少训练计算量这个是毋庸置疑的,泛化能力增强这个可能不太好理解,因为原本要更新一个节点需要它周围的所有邻居,而通过邻居采样之后,每个节点就不是由所有的邻居来更新它,而是部分邻居节点,所以具有比较强的泛化能力。
1.1.1 论文角度看GraphSage
聚合函数的选取
在图中顶点的邻居是无序的,所以希望构造出的聚合函数是对称的(即也就是对它输入的各种排列,函数的输出结果不变),同时具有较高的表达能力。 聚合函数的对称性(symmetry property)确保了神经网络模型可以被训练且可以应用于任意顺序的顶点邻居特征集合上。
**a. Mean aggregator **:
mean aggregator将目标顶点和邻居顶点的第
k
−
1
k−1
k−1层向量拼接起来,然后对向量的每个维度进行求均值的操作,将得到的结果做一次非线性变换产生目标顶点的第
k
k
k层表示向量。
卷积聚合器Convolutional aggregator:
文中用下面的式子替换算法1中的4行和5行得到GCN的inductive变形:
h v k ← σ ( W ⋅ MEAN ( h v k − 1 ∪ h u k − 1 , ∀ u ∈ N ( v ) ) ) \\mathbfh_v^k \\leftarrow \\sigma\\left(\\mathbfW \\cdot \\operatornameMEAN\\left(\\left\\\\mathbfh_v^k-1\\right\\ \\cup\\left\\\\mathbfh_u^k-1, \\forall u \\in \\mathcalN(v)\\right\\\\right)\\right) hvk←σ(W⋅MEAN(hvk−1∪huk−1,∀u∈N(v)))
原始算法1中的第4,5行是
h N ( v ) k ← AGGREGATE k ( h u k − 1 , ∀ u ∈ N ( v ) ) h v k ← σ ( W k ⋅ CONCAT ( h v k − 1 , h N ( v ) k ) ) \\beginarrayl \\mathbfh_\\mathcalN(v)^k \\leftarrow \\operatornameAGGREGATE_k\\left(\\left\\\\mathbfh_u^k-1, \\forall u \\in \\mathcalN(v)\\right\\\\right) \\\\ \\mathbfh_v^k \\leftarrow \\sigma\\left(\\mathbfW^k \\cdot \\operatornameCONCAT\\left(\\mathbfh_v^k-1, \\mathbfh_\\mathcalN(v)^k\\right)\\right) \\endarray hN(v)k←AGGREGATEk(huk−1,∀u∈N(v))hvk←σ(Wk⋅CONCAT(hvk−1,hN(v)k))
论文提出的均值聚合器Mean aggregator:
h v k ← σ ( W ⋅ MEAN ( h v k − 1 ∪ h u k − 1 , ∀ u ∈ N ( v ) ) ) h v k ← σ ( W k ⋅ CONCAT ( h v k − 1 , h N ( v ) k ) ) \\beginarrayl \\mathbfh_v^k \\leftarrow \\sigma\\left(\\mathbfW \\cdot \\operatornameMEAN\\left(\\left\\\\mathbfh_v^k-1\\right\\ \\cup\\left\\\\mathbfh_u^k-1, \\forall u \\in \\mathcalN(v)\\right\\\\right)\\right) \\\\ \\mathbfh_v^k \\leftarrow \\sigma\\left(\\mathbfW^k \\cdot \\operatornameCONCAT\\left(\\mathbfh_v^k-1, \\mathbfh_\\mathcalN(v)^k\\right)\\right) \\endarray hvk←σ(W⋅MEAN(hvk−1∪huk−1,∀u∈N(v)))hvk←σ(Wk⋅CONCAT(hvk−1,hN(v)k))
- 均值聚合近似等价在transducttive GCN框架中的卷积传播规则
- 这个修改后的基于均值的聚合器是convolutional的。但是这个卷积聚合器和文中的其他聚合器的重要不同在于它没有算法1中第5行的CONCAT操作——卷积聚合器没有将顶点前一层的表示 h v k − 1 \\mathbfh^k-1_v hvk−1聚合的邻居向量 h N ( v ) k \\mathbfh^k_\\mathcalN(v) hN(v)k拼接起来
- 拼接操作可以看作一个是GraphSAGE算法在不同的搜索深度或层之间的简单的skip connection[Identity mappings in deep residual networks]的形式,它使得模型的表征性能获得了巨大的提升
- 举个简单例子,比如一个节点的3个邻居的embedding分别为[1,2,3,4],[2,3,4,5],[3,4,5,6]按照每一维分别求均值就得到了聚合后的邻居embedding为[2,3,4,5]
b. LSTM aggregator
文中也测试了一个基于LSTM的复杂的聚合器[Long short-term memory]。和均值聚合器相比,LSTMs有更强的表达能力。但是,LSTMs不是对称的(symmetric),也就是说不具有排列不变性(permutation invariant),因为它们以一个序列的方式处理输入。因此,需要先对邻居节点随机顺序,然后将邻居序列的embedding作为LSTM的输入。
- 排列不变性(permutation invariance):指输入的顺序改变不会影响输出的值。
c. Pooling aggregator
pooling聚合器,它既是对称的,又是可训练的。Pooling aggregator 先对目标顶点的邻居顶点的embedding向量进行一次非线性变换,之后进行一次pooling操作(max pooling or mean pooling),将得到结果与目标顶点的表示向量拼接,最后再经过一次非线性变换得到目标顶点的第k层表示向量。
一个element-wise max pooling操作应用在邻居集合上来聚合信息:
h N ( v ) k = A G G R E G A T E k pool = max ( σ ( W pool h u k − 1 + b ) , ∀ u ∈ N ( v ) ) h v k ← σ ( W k ⋅ CONCAT ( h v k − 1 , h N ( v ) k ) ) \\beginaligned \\mathbfh_\\mathcalN(v)^k=& \\mathrmAGGREGATE_k^\\text pool=\\max \\left(\\left\\\\sigma\\left(\\mathbfW_\\text pool \\mathbfh_u^k-1+\\mathbfb\\right), \\forall u \\in \\mathcalN(v)\\right\\\\right) \\\\ &\\mathbfh_v^k \\leftarrow \\sigma\\left(\\mathbfW^k \\cdot \\operatornameCONCAT\\left(\\mathbfh_v^k-1, \\mathbfh_\\mathcalN(v)^k\\right)\\right) \\endaligned hN(v)k=AGGREGATEkpool=max(σ(Wpoolhuk−1+b),∀u∈N(v))hv以上是关于PGL图学习之图神经网络GraphSAGEGIN图采样算法[系列七]的主要内容,如果未能解决你的问题,请参考以下文章
PGL图学习之图神经网络ERNIESageUniMP进阶模型[系列八]
PGL图学习之基于UniMP算法的论文引用网络节点分类任务[系列九]
6.Paddle Graph Learning (PGL)图学习之图游走类模型[系列四]