GraphSAGE 代码解析 - minibatch.py
Posted 认真积累每一天
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了GraphSAGE 代码解析 - minibatch.py相关的知识,希望对你有一定的参考价值。
class EdgeMinibatchIterator
""" This minibatch iterator iterates over batches of sampled edges or random pairs of co-occuring edges. G -- networkx graph id2idx -- dict mapping node ids to index in feature tensor placeholders -- tensorflow placeholders object context_pairs -- if not none, then a list of co-occuring node pairs (from random walks) batch_size -- size of the minibatches max_degree -- maximum size of the downsampled adjacency lists n2v_retrain -- signals that the iterator is being used to add new embeddings to a n2v model fixed_n2v -- signals that the iterator is being used to retrain n2v with only existing nodes as context """
def __init__(self, G, id2idx, placeholders, context_pairs=None, batch_size=100, max_degree=25,
n2v_retrain=False, fixed_n2v=False, **kwargs) 中具体介绍以下:
1 self.nodes = np.random.permutation(G.nodes())
2 # 函数shuffle与permutation都是对原来的数组进行重新洗牌,即随机打乱原来的元素顺序
3 # shuffle直接在原来的数组上进行操作,改变原来数组的顺序,无返回值
4 # permutation不直接在原来的数组上进行操作,而是返回一个新的打乱顺序的数组,并不改变原来的数组。
1 self.adj, self.deg = self.construct_adj()
这里重点看construct_adj()函数。
1 def construct_adj(self): 2 adj = len(self.id2idx) * 3 np.ones((len(self.id2idx) + 1, self.max_degree)) 4 # 该矩阵记录训练数据中各节点的邻居节点的编号 5 # 采样只取max_degree个邻居节点,采样方法见下 6 # 同样进行了行数加一操作 7 8 deg = np.zeros((len(self.id2idx),)) 9 # 该矩阵记录了每个节点的度数 10 11 for nodeid in self.G.nodes(): 12 if self.G.node[nodeid][‘test‘] or self.G.node[nodeid][‘val‘]: 13 continue 14 neighbors = np.array([self.id2idx[neighbor] 15 for neighbor in self.G.neighbors(nodeid) 16 if (not self.G[nodeid][neighbor][‘train_removed‘])]) 17 # Graph.neighbors() Return a list of the nodes connected to the node n. 18 # 在选取邻居节点时进行了筛选,对于G.neighbors(nodeid) 点node的邻居, 19 # 只取该node与neighbor相连的边的train_removed = False的neighbor 20 # 也就是只取不是val, test的节点。 21 # neighbors得到了邻居节点编号数列。 22 23 deg[self.id2idx[nodeid]] = len(neighbors) 24 # deg各位取值为该位对应nodeid的节点的度数, 25 # 也即经过上面筛选后得到的邻居数 26 27 if len(neighbors) == 0: 28 continue 29 if len(neighbors) > self.max_degree: 30 neighbors = np.random.choice( 31 neighbors, self.max_degree, replace=False) 32 # range: neighbors; size = max_degree; replace: replace the origin matrix or not 33 # np.random.choice为选取size大小的数列 34 35 elif len(neighbors) < self.max_degree: 36 neighbors = np.random.choice( 37 neighbors, self.max_degree, replace=True) 38 # 经过choice随机选取,得到了固定大小max_degree = 25的直接相连的邻居数列 39 40 adj[self.id2idx[nodeid], :] = neighbors 41 # 把该node的邻居数列,赋值给adj矩阵中对应nodeid位的向量。 42 return adj, deg
在construct_test_adj() 函数中,与上不同之处在于,可以直接得到邻居而无需根据val/test/train_removed筛选.
1 neighbors = np.array([self.id2idx[neighbor] 2 for neighbor in self.G.neighbors(nodeid)])
以上是关于GraphSAGE 代码解析 - minibatch.py的主要内容,如果未能解决你的问题,请参考以下文章